# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import functools
import gc
import weakref

from absl.testing import parameterized
import numpy as np

from tensorflow.python import pywrap_tfe
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import forwardprop
from tensorflow.python.eager import forwardprop_util
from tensorflow.python.eager import tape as tape_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.parallel_for import control_flow_ops
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.platform import test
from tensorflow.python.util import nest

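# The value of f(x) = x**3.5 at x = 1.1 followed by its first, second, and
# third derivatives; used by testHigherOrderPureForward below.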
_X11_35_DERIVATIVES = [
    1.1**3.5, 3.5 * 1.1**2.5, 3.5 * 2.5 * 1.1**1.5, 3.5 * 2.5 * 1.5 * 1.1**0.5
]


# TODO(allenl): Move this somewhere useful once forward gradients are stable.
def _jvp(f, primals, tangents):
  """Compute the jacobian of `f` at `primals` multiplied by `tangents`."""
  with forwardprop.ForwardAccumulator(primals, tangents) as acc:
    primals_out = f(*primals)
  return primals_out, acc.jvp(
      primals_out, unconnected_gradients=UnconnectedGradients.ZERO)
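

# Illustrative sketch added for clarity (not part of the original test suite;
# the helper name below is hypothetical): `_jvp` runs `f` under a
# ForwardAccumulator, so for exp at primal 1.0 with tangent 0.5 the JVP is
# exp(1.0) * 0.5.
def _example_jvp_sketch():
  primal_out, jvp = _jvp(math_ops.exp, (constant_op.constant(1.0),),
                         (constant_op.constant(0.5),))
  return primal_out, jvp  # jvp is approximately exp(1.0) * 0.5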


def _jacfwd(f, primals):
  """Compute the jacobian of `f` at `primals` using forward-mode autodiff."""
  jac_flat = []
  flat_primals = nest.flatten(primals)
  tangent_mask = [
      array_ops.zeros_like(primal, dtype=primal.dtype)
      for primal in flat_primals
  ]
  for primal_index, primal in enumerate(flat_primals):
    primal_vector = array_ops.reshape(primal, [-1])
    primal_vector_length = array_ops.size(primal_vector)
    jac_columns = []
    for element_index in math_ops.range(primal_vector_length):
      mask = array_ops.one_hot(
          element_index, primal_vector_length, dtype=primal.dtype)
      tangent_mask[primal_index] = array_ops.reshape(mask,
                                                     array_ops.shape(primal))
      jac_columns.append(
          nest.map_structure(
              functools.partial(array_ops.reshape, shape=[-1]),
              _jvp(f, primals, nest.pack_sequence_as(primals,
                                                     tangent_mask))[1]))
    jac_flat.append(array_ops.stack(jac_columns, axis=1))
    tangent_mask[primal_index] = array_ops.zeros_like(primal)
  return nest.pack_sequence_as(primals, jac_flat)
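

# Illustrative sketch added for clarity (not part of the original test suite;
# the helper name below is hypothetical): for elementwise sin, `_jacfwd`
# produces a diagonal Jacobian with cos(x) on the diagonal.
def _example_jacfwd_sketch():
  x = constant_op.constant([1., 2.])
  jacobian, = _jacfwd(math_ops.sin, [x])
  return jacobian  # approximately [[cos(1.), 0.], [0., cos(2.)]]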


def _jvp_batch(f, primal, tangents):
  tf_function = def_function.function(f)

  return control_flow_ops.vectorized_map(
      functools.partial(_jvp, tf_function, primal), tangents)


def _jvp_batch_matmul(f, primals, tangent_batch):
  """Multiply the jacobian of `f` at `primals` by each batched tangent."""
  jac_fwd = _jacfwd(f, primals)

  def jac_mul(tangent):
    flat_tangent = array_ops.reshape(tangent, shape=[-1])
    tangent_vector = array_ops.expand_dims(flat_tangent, 1)
    jvp_vector = math_ops.matmul(jac_fwd, tangent_vector)
    return array_ops.reshape(jvp_vector, tangent.shape)

  return control_flow_ops.vectorized_map(jac_mul, tangent_batch)
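

# Illustrative sketch added for clarity (not part of the original test suite;
# the helper name and shapes below are assumptions): `_jvp_batch_matmul`
# materializes the full Jacobian once and multiplies it with every tangent in
# the batch. BatchTests.testJVPBatchCorrectness below checks that it agrees
# with `_jvp_batch`.
def _example_jvp_batch_matmul_sketch():
  primals = [random_ops.random_uniform([2, 3])]
  tangent_batch = random_ops.random_uniform([5, 2, 3])
  return _jvp_batch_matmul(math_ops.sin, primals, tangent_batch)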


def _grad(f, argnums=0):
  """Return a function which computes the gradient of `f`."""

  def _f(*params):
    with backprop.GradientTape() as tape:
      tape.watch(params)
      primals_out = f(*params)
    return tape.gradient(
        primals_out,
        params[argnums],
        unconnected_gradients=UnconnectedGradients.ZERO)

  return _f
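

# Illustrative sketch added for clarity (not part of the original test suite;
# the helper name below is hypothetical): `_grad` wraps `f` in a GradientTape,
# so the gradient of x**2 at 3.0 is 6.0.
def _example_grad_sketch():
  square_grad = _grad(lambda x: x**2)
  return square_grad(constant_op.constant(3.))  # approximately 6.0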


def _gradfwd(f, argnums=0, f_out_dtypes=dtypes.float32):
  """Return a function which computes the gradient of `f` in forward mode."""

  def _f(*params):

    def _single_jvp(param_mask):
      with forwardprop.ForwardAccumulator(
          primals=[params[argnums]], tangents=param_mask) as acc:
        primals_out = f(*params)
      return acc.jvp(primals_out)

    # Building up a function to run with pfor takes a bit too long since we're
    # only running it a handful of times.
    return _vectorize_parameters(
        _single_jvp, [params[argnums]], use_pfor=False, dtype=f_out_dtypes)

  return _f
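

# Illustrative sketch added for clarity (not part of the original test suite;
# the helper name below is hypothetical): `_gradfwd` rebuilds a full gradient
# from forward-mode JVPs, one one-hot tangent per parameter element, so for
# f(x) = sum(x**2) the result is 2 * x.
def _example_gradfwd_sketch():
  sum_square_gradfwd = _gradfwd(lambda x: math_ops.reduce_sum(x**2))
  return sum_square_gradfwd(constant_op.constant([1., 2., 3.]))  # ~2. * x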


def _hvp(f, primals, tangents):
  """Compute a forward-over-back Hessian-vector product."""
  with forwardprop.ForwardAccumulator(primals, tangents) as acc:
    with backprop.GradientTape() as tape:
      tape.watch(primals)
      f_out = f(*primals)
      f_out.shape.assert_is_compatible_with([])
    return acc.jvp(tape.gradient(f_out, primals))
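

# Illustrative sketch added for clarity (not part of the original test suite;
# the helper name below is hypothetical): for f(x) = sum(x**2) the Hessian is
# 2 * identity, so the forward-over-back HVP with tangent `v` is 2 * v.
def _example_hvp_sketch():
  x = constant_op.constant([1., 2., 3.])
  v = constant_op.constant([1., 0., -1.])
  hvp, = _hvp(lambda t: math_ops.reduce_sum(t**2), (x,), (v,))
  return hvp  # approximately [2., 0., -2.]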


def _vectorize_parameters(f, params, use_pfor, dtype):
  """Loop over `params`, providing a one-hot mask to `f` for each."""
  parameter_sizes = [array_ops.size(param) for param in params]
  total_size = math_ops.add_n(parameter_sizes)

  def _wrapper(index):
    full_onehot = array_ops.one_hot(index, total_size)
    split_onehot = array_ops.split(full_onehot, parameter_sizes)
    tangents = [
        array_ops.reshape(v, array_ops.shape(param))
        for param, v in zip(params, split_onehot)
    ]
    return f(tangents)

  if use_pfor:
    return control_flow_ops.vectorized_map(_wrapper, math_ops.range(total_size))

  return map_fn.map_fn(_wrapper, math_ops.range(total_size), dtype)
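

# Illustrative sketch added for clarity (not part of the original test suite;
# the helper name below is hypothetical): flattening and concatenating the
# one-hot masks that `_vectorize_parameters` passes to `f` for two parameters
# of sizes 2 and 3 stacks up to a 5x5 identity matrix.
def _example_vectorize_parameters_sketch():

  def _flatten_mask(tangents):
    return array_ops.concat(
        [array_ops.reshape(t, [-1]) for t in tangents], axis=0)

  params = [constant_op.constant([1., 2.]),
            constant_op.constant([3., 4., 5.])]
  return _vectorize_parameters(
      _flatten_mask, params, use_pfor=False, dtype=dtypes.float32)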


def _forward_over_back_hessian(f, params, use_pfor, dtype=None):
  """Computes the full Hessian matrix for the scalar-valued f(*params).

  Args:
    f: A function taking `params` and returning a scalar.
    params: A possibly nested structure of tensors.
    use_pfor: If true, uses `tf.vectorized_map` calls instead of looping.
    dtype: Required if `use_pfor=False`. A possibly nested structure of dtypes
      (e.g. `tf.float32`) matching the structure of `f`'s returns.

  Returns:
    A possibly nested structure of matrix slices corresponding to `params`. Each
    slice has shape [P, p_s] where `p_s` is the number of parameters (`tf.size`)
    in the corresponding element of `params` and `P` is the total number of
    parameters (`sum_s(p_s)`). The full matrix can be obtained by concatenating
    along the second axis.
  """
  return _vectorize_parameters(
      functools.partial(_hvp, f, params),
      params,
      use_pfor=use_pfor,
      dtype=dtype)
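

# Illustrative sketch added for clarity (not part of the original test suite;
# the helper name below is hypothetical), mirroring
# HessianTests.testHessian1D: for f(x) = sum(x**2) the full Hessian is
# 2 * identity.
def _example_forward_over_back_hessian_sketch():
  x = constant_op.constant([1., 2., 3.])
  hessian, = _forward_over_back_hessian(
      lambda t: math_ops.reduce_sum(t**2), [x],
      use_pfor=False,
      dtype=[dtypes.float32])
  return hessian  # approximately 2 * identity, shape [3, 3]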


def _test_gradients(testcase,
                    f,
                    primals,
                    order,
                    delta=1e-3,
                    rtol=1e-2,
                    atol=1e-6,
                    srtol=1e-6,
                    satol=1e-6):
  """Tests forward/backward jacobians of `f`'s [0, `order`)-order gradients."""
  if order < 1:
    raise ValueError(
        "`order` should be a positive integer, got '{}'.".format(order))
  if order > 1:
    _test_gradients(
        testcase=testcase,
        f=_grad(f),
        primals=primals,
        order=order - 1,
        delta=delta,
        rtol=rtol,
        atol=atol,
        srtol=srtol,
        satol=satol)
  sym_jac_back, num_jac = gradient_checker_v2.compute_gradient(
      f, primals, delta=delta)
  testcase.assertAllClose(num_jac, sym_jac_back, rtol=rtol, atol=atol)
  sym_jac_fwd = _jacfwd(f, primals)
  testcase.assertAllClose(num_jac, sym_jac_fwd, rtol=rtol, atol=atol)
  # And the symbolic computations should be much closer.
  testcase.assertAllClose(sym_jac_back, sym_jac_fwd, rtol=srtol, atol=satol)


@test_util.with_eager_op_as_function
class ForwardpropTest(test.TestCase, parameterized.TestCase):

  def testJVPFunction(self):
    add_outputs = (constant_op.constant(4.),)
    vp, = forwardprop._jvp_dispatch(
        op_name="Add",
        attr_tuple=(),
        inputs=(constant_op.constant(1.), constant_op.constant(3.)),
        outputs=add_outputs,
        tangents=(
            constant_op.constant(1.),
            constant_op.constant(5.),
        ))
    self.assertAllClose(1. + 5., self.evaluate(vp))

    mul_outputs = (constant_op.constant([20.]),)
    vp, = forwardprop._jvp_dispatch(
        op_name="Mul",
        attr_tuple=(),
        inputs=(constant_op.constant([4.]), constant_op.constant([5.])),
        outputs=mul_outputs,
        tangents=(
            constant_op.constant([2.]),
            constant_op.constant([3.]),
        ))
    self.assertAllClose([2. * 5. + 3. * 4.], self.evaluate(vp))

  def testJVPFunctionWithBatchOfTangents(self):
    add_outputs = (constant_op.constant(4.),)
    jvp_flat = forwardprop._jvp_dispatch(
        op_name="Add",
        attr_tuple=(),
        inputs=(constant_op.constant(1.), constant_op.constant(3.)),
        outputs=add_outputs,
        tangents=(
            constant_op.constant([1., 2., 3.]),
            constant_op.constant([4., 5., 6.]),
        ),
        use_batch=True)

    # Using evaluate and asserting with just a list works too, but the
    # output is more explicit this way.
    self.assertAllClose([constant_op.constant([1. + 4., 2. + 5., 3. + 6.])],
                        jvp_flat)

    mul_outputs = (constant_op.constant([20.]),)
    jvp_flat = forwardprop._jvp_dispatch(
        op_name="Mul",
        attr_tuple=(),
        inputs=(constant_op.constant([4.]), constant_op.constant([5.])),
        outputs=mul_outputs,
        tangents=(
            constant_op.constant([[1.], [0.], [1.]]),
            constant_op.constant([[0.], [1.], [1.]]),
        ),
        use_batch=True)
    self.assertAllClose([constant_op.constant([[5.], [4.], [5. + 4.]])],
                        jvp_flat)

  def testJVPFunctionRaisesError(self):
    sum_outputs = (constant_op.constant(6.),)

    with self.assertRaisesRegex(ValueError, r".*was expected to be of shape*"):
      forwardprop._jvp_dispatch(
          op_name="Add",
          attr_tuple=(),
          inputs=(constant_op.constant(2.), constant_op.constant(4.)),
          outputs=sum_outputs,
          tangents=(constant_op.constant([1., 2.]),
                    constant_op.constant([[1.], [2.]])),
          use_batch=True)

  def testNonDifferentiableOpWithInputTangent(self):
    x = constant_op.constant(1.)
    with forwardprop.ForwardAccumulator(x, 2.) as acc1:
      with forwardprop.ForwardAccumulator(x, 2.) as acc2:
        y = array_ops.zeros_like(x)
      self.assertIsNone(acc1.jvp(y))
    self.assertIsNone(acc2.jvp(y))

  def testRunFunctionsEagerly(self):
    try:
      original_setting = def_function.functions_run_eagerly()
      def_function.run_functions_eagerly(True)
      x = constant_op.constant(1.)
      with forwardprop.ForwardAccumulator(x, 2.) as acc:
        y = x * 3.
      self.assertAllClose(6., acc.jvp(y))
    finally:
      def_function.run_functions_eagerly(original_setting)

  def testJVPFunctionUsedByAccumulatorForOps(self):
    previous_fn = forwardprop._jvp_dispatch
    try:
      x = constant_op.constant(1.)
      with forwardprop.ForwardAccumulator(x, 2.) as acc:
        y = x + x
        pywrap_tfe.TFE_Py_RegisterJVPFunction(
            lambda *args, **kwargs: [constant_op.constant(-15.)])
        z = x + x
      self.assertAllClose(4., acc.jvp(y))
      self.assertAllClose(-15., acc.jvp(z))
    finally:
      pywrap_tfe.TFE_Py_RegisterJVPFunction(previous_fn)

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testFunctionCacheLimited(self):
    # Every time this loop is executed, it will create a slightly larger Tensor
    # and push it through Add's gradient.
    # We run TRACE_COUNT_LIMIT x 2 so that it is tested with both
    # experimental_relax_shapes on and off.
    for execution_count in range(forwardprop._TRACE_COUNT_LIMIT * 2):
      x = array_ops.zeros([execution_count])
      with forwardprop.ForwardAccumulator(x, array_ops.ones_like(x)) as acc:
        y = x + x
      self.assertAllClose(2. * array_ops.ones_like(x), acc.jvp(y))

  def testVariableUnwatchedZero(self):
    v = variables.Variable([[1.]])
    x = constant_op.constant(1.)
    xt = constant_op.constant(2.)
    with forwardprop.ForwardAccumulator(x, xt) as acc:
      pass
    self.assertIsNone(acc.jvp(v))
    self.assertAllClose([[0.]], acc.jvp(v, unconnected_gradients="zero"))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testFunctionReturnsResource(self):
    v = variables.Variable([[1.]])
    x = constant_op.constant(1.)
    xt = constant_op.constant(2.)

    @def_function.function
    def f(a):
      return a, v.handle

    with forwardprop.ForwardAccumulator(x, xt) as acc:
      y, _ = f(x)
    self.assertAllClose(2., acc.jvp(y))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testMultipleWatchesAdd(self):
    x = constant_op.constant(-2.)
    with self.assertRaisesRegex(ValueError, "multiple times"):
      with forwardprop.ForwardAccumulator([x, x], [1., 2.]):
        pass
    with forwardprop.ForwardAccumulator([x], [3.]) as acc:
      self.assertAllClose(3., acc.jvp(x))
      acc._watch(x, constant_op.constant(10.))
      self.assertAllClose(13., acc.jvp(x))
      acc._watch(x, constant_op.constant(11.))
      self.assertAllClose(24., acc.jvp(x))
      y = constant_op.constant(3.) * x
    self.assertAllClose(24., acc.jvp(x))
    self.assertAllClose(24. * 3., acc.jvp(y))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testReenter(self):
    x = constant_op.constant(-2.)
    with forwardprop.ForwardAccumulator(x, 1.5) as acc:
      self.assertAllClose(1.5, acc.jvp(x))
      y = 4. * x
      self.assertAllClose(6., acc.jvp(y))
      with self.assertRaisesRegex(ValueError, "already recording"):
        with acc:
          pass
    z = 4. * x
    self.assertIsNone(acc.jvp(z))
    with acc:
      yy = y * y
    self.assertAllClose(6. * -8. * 2., acc.jvp(yy))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testDeadTensorsJVPCleared(self):
    x = array_ops.ones([100])
    x_weak = weakref.ref(x)
    grad_tensor = constant_op.constant(array_ops.zeros([100]))
    grad_tensor_weak = weakref.ref(grad_tensor)
    with forwardprop.ForwardAccumulator(x, grad_tensor) as acc:
      derived_tensor = constant_op.constant(2.) * x
      del grad_tensor
      self.assertAllClose(array_ops.zeros([100]), acc.jvp(x))
      del x
      self.assertIsNone(x_weak())
      self.assertIsNone(grad_tensor_weak())
      derived_tensor_weak = weakref.ref(derived_tensor)
      derived_tensor_grad = acc.jvp(derived_tensor)
      derived_tensor_grad_weak = weakref.ref(derived_tensor_grad)
      del derived_tensor
      del derived_tensor_grad
      self.assertIsNone(derived_tensor_weak())
      self.assertIsNone(derived_tensor_grad_weak())

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testJVPManual(self):
    primal, tangent = _jvp(math_ops.sin, (constant_op.constant(0.1),),
                           (constant_op.constant(0.2),))
    self.assertAllClose(math_ops.sin(0.1), primal)
    self.assertAllClose(math_ops.cos(0.1) * 0.2, tangent)

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testNumericHigherOrder(self):

    def f(x):
      pointwise = math_ops.sin(x) * math_ops.tan(x)
      return math_ops.reduce_prod(
          pointwise + math_ops.reduce_sum(pointwise), axis=1)

    if (context.run_eager_op_as_function_enabled() and
        test_util.is_xla_enabled()):
      # Autoclustering kicks in when eager_op_as_function is enabled.
      # Under XLA the symbolic tolerances are less than under TF.
      # Ref: b/202559426
      _test_gradients(
          self,
          f, [constant_op.constant([[2.0, 3.0], [1.0, 4.0]])],
          order=3,
          srtol=1e-6,
          satol=1e-3)
    else:
      _test_gradients(
          self, f, [constant_op.constant([[2.0, 3.0], [1.0, 4.0]])], order=3)

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testNumericHigherOrderFloat64(self):

    def f(x):
      pointwise = math_ops.sin(x) * math_ops.tan(x)
      return math_ops.reduce_prod(
          pointwise + math_ops.reduce_sum(pointwise), axis=1)

    _test_gradients(
        self,
        f,
        [constant_op.constant([[2.0, 3.0], [1.0, 4.0]], dtype=dtypes.float64)],
        order=3)

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testCustomGradient(self):

    @custom_gradient.custom_gradient
    def f(x):

      def grad(dy):
        return dy * math_ops.cos(x)

      return np.sin(x.numpy()), grad

    _test_gradients(self, f, [constant_op.constant([1., 2.])], order=3)

  # TODO(allenl): investigate why assert_no_new_pyobjects_executing_eagerly
  # fails around this test.
  def testExceptionCustomGradientRecomputeGradForward(self):

    @custom_gradient.recompute_grad
    def f(x):
      return math_ops.reduce_prod(math_ops.tanh(x)**2)

    with self.assertRaisesRegex(NotImplementedError,
                                "recompute_grad tried to transpose"):
      primals = [constant_op.constant([1.])]
      sym_jac_fwd = _jacfwd(f, primals)

  def testExceptionInCustomGradientNotSwallowed(self):

    @custom_gradient.custom_gradient
    def f(unused_x):

      def grad(unused_dy):
        raise ValueError("test_error_string")

      return 1., grad

    c = constant_op.constant(1.)
    d = constant_op.constant(2.)
    with forwardprop.ForwardAccumulator(c, d):
      with self.assertRaisesRegex(ValueError, "test_error_string"):
        f(c)

  @parameterized.named_parameters([("EluM5", -0.5, nn_ops.elu),
                                   ("EluP5", [0.5], nn_ops.elu),
                                   ("SwishP5", 0.5, nn_impl.swish),
                                   ("SwishM5", [-0.5], nn_impl.swish)])
  def testElementwiseNNOps(self, value, op_fn):
    _test_gradients(self, op_fn, [constant_op.constant(value)], order=3)

  def testFusedBatchNormGradsInference(self):

    x_shape = [4, 10, 10, 2]
    increment = 3. / math_ops.reduce_prod(
        constant_op.constant(x_shape, dtype=dtypes.float32))
    x = array_ops.reshape(math_ops.range(-2., 1., increment), x_shape)
    scale = constant_op.constant([1., 1.1])
    offset = constant_op.constant([-0.5, -0.6])
    mean = constant_op.constant([-1.3, 1.4])
    variance = constant_op.constant([0.7, 0.9])
    epsilon = 0.001

    def _bn_fused(x_arg, scale_arg, offset_arg):
      return nn_impl.fused_batch_norm(
          x_arg,
          scale_arg,
          offset_arg,
          mean,
          variance,
          epsilon=epsilon,
          is_training=False)[0]

    _test_gradients(self, _bn_fused, [x, scale, offset], order=2, atol=1e-2)

  def testPushPopAccumulatorState(self):
    # Note that this example is somewhat contrived. push_forwardprop_state is
    # probably only useful in practice for building functions that compute jvps
    # alongside their usual outputs.
    c = constant_op.constant(1.)
    d = constant_op.constant(2.)
    with forwardprop.ForwardAccumulator(c, d) as acc:

      @custom_gradient.custom_gradient
      def f(x):
        y = math_ops.sin(x.numpy())

        def grad(dy):
          with forwardprop_util.push_forwardprop_state():
            x_copy = constant_op.constant(x.numpy())
            acc._watch(x_copy, dy)
            y_copy = math_ops.sin(x_copy)
          return dy * acc.jvp(y_copy)

        return y, grad

      output = f(c)
      self.assertAllClose(d * math_ops.cos(c), acc.jvp(output))

  @parameterized.named_parameters([
      ("Order{}".format(order), order, expected)
      for order, expected in enumerate(_X11_35_DERIVATIVES)
  ])
  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testHigherOrderPureForward(self, order, expected):

    def _forwardgrad(f):

      def _compute_forwardgrad(primal):
        tangent = constant_op.constant(1.)
        with forwardprop.ForwardAccumulator(primal, tangent) as acc:
          primal_out = f(primal)
        return acc.jvp(primal_out)

      return _compute_forwardgrad

    def _forward(x):
      return x**3.5

    f = _forward
    primal = constant_op.constant(1.1)
    for _ in range(order):
      f = _forwardgrad(f)
    self.assertAllClose(expected, f(primal))

  @parameterized.named_parameters([("Function", def_function.function),
                                   ("NoFunction", lambda f: f)])
  def testGradPureForward(self, decorator):

    @decorator
    def f(x):
      return x**3.5

    primal = constant_op.constant(1.1)
    with forwardprop.ForwardAccumulator(primal,
                                        constant_op.constant(1.)) as outer_acc:
      with forwardprop.ForwardAccumulator(primal,
                                          constant_op.constant(1.)) as acc:
        primal_out = f(primal)
    inner_jvp = acc.jvp(primal_out)
    outer_jvp = outer_acc.jvp(inner_jvp)
    self.assertAllClose(1.1**3.5, primal_out)
    self.assertAllClose(3.5 * 1.1**2.5, inner_jvp)
    self.assertAllClose(3.5 * 2.5 * 1.1**1.5, outer_jvp)
    self.assertIsNone(acc.jvp(outer_acc.jvp(primal_out)))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testJVPPacking(self):
    two = constant_op.constant(2.)
    primal_in = constant_op.constant(1.)
    inner_jvp = constant_op.constant(3.)
    with forwardprop.ForwardAccumulator(
        [primal_in, inner_jvp],
        [constant_op.constant(2.),
         constant_op.constant(4.)]) as outer_acc:
      with forwardprop.ForwardAccumulator(primal_in, inner_jvp) as inner_acc:
        packed_input_indices, packed_input_tangents = (
            forwardprop_util.pack_tangents([primal_in]))
        self.assertAllClose([3., 2., 4.], packed_input_tangents)
        expected_indices = (
            # inner_acc watches primal_in
            (
                (0, 1),),
            # outer_acc watches primal_in and inner_jvp
            ((0, 2), (1, 3)))
        self.assertAllEqual(expected_indices, packed_input_indices)
        primal_out = primal_in * two
        self.assertAllClose(6., inner_acc.jvp(primal_out))
        self.assertAllClose(4., outer_acc.jvp(primal_out))
        self.assertAllClose(8., outer_acc.jvp(inner_acc.jvp(primal_out)))
        packed_output_indices, packed_output_tangents = (
            forwardprop_util.pack_tangents([primal_out]))
        self.assertAllClose([6., 4., 8.], packed_output_tangents)
        self.assertAllEqual(expected_indices, packed_output_indices)

  def testFunctionGradInFunctionPureForward(self):

    @def_function.function
    def take_gradients():

      @def_function.function
      def f(x):
        return x**3.5

      primal = constant_op.constant(1.1)
      with forwardprop.ForwardAccumulator(
          primal, constant_op.constant(1.)) as outer_acc:
        with forwardprop.ForwardAccumulator(primal,
                                            constant_op.constant(1.)) as acc:
          primal_out = f(primal)
      inner_jvp = acc.jvp(primal_out)
      outer_jvp = outer_acc.jvp(inner_jvp)
      self.assertIsNone(acc.jvp(outer_acc.jvp(primal_out)))
      return primal_out, inner_jvp, outer_jvp

    primal_out, inner_jvp, outer_jvp = take_gradients()
    self.assertAllClose(1.1**3.5, primal_out)
    self.assertAllClose(3.5 * 1.1**2.5, inner_jvp)
    self.assertAllClose(3.5 * 2.5 * 1.1**1.5, outer_jvp)

  def testFunctionGrad(self):

    @def_function.function
    def f(x):
      return math_ops.reduce_prod(math_ops.tanh(x)**2)

    _test_gradients(self, f, [constant_op.constant([1., 2.])], order=3)

  def testReusingJVP(self):
    m1 = random_ops.random_uniform((256, 2096))
    m2 = array_ops.identity(m1)
    tangent1 = random_ops.random_uniform((256, 2096))
    tangent2 = random_ops.random_uniform((256, 2096))
    matmul = def_function.function(math_ops.matmul)

    with forwardprop.ForwardAccumulator(
        primals=[m1, m2], tangents=[tangent1, tangent2]) as acc:
      result1 = matmul(m1, m1, transpose_b=True)
      result2 = matmul(m2, m2, transpose_b=True)

    def _expected(mat, tangent):
      return (math_ops.matmul(tangent, mat, transpose_b=True) +
              math_ops.matmul(mat, tangent, transpose_b=True))

    self.assertAllClose(result1, result2)
    self.assertAllClose(_expected(m1, tangent1), acc.jvp(result1))
    self.assertAllClose(_expected(m2, tangent2), acc.jvp(result2))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testHVPMemory(self):

    def fun(x):
      return math_ops.reduce_prod(math_ops.tanh(x)**2)

    primals = constant_op.constant([1., 2., 3.])
    tangents = constant_op.constant([3., 4., 5.])
    _hvp(fun, (primals,), (tangents,))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testHVPCorrectness(self):

    def fun(x):
      return math_ops.reduce_prod(math_ops.tanh(x)**2)

    primals = constant_op.constant([1., 2., 3.])
    tangents = constant_op.constant([3., 4., 5.])
    forwardback_hvp_eager, = _hvp(fun, (primals,), (tangents,))
    forwardback_hvp_function, = def_function.function(_hvp)(fun, (primals,),
                                                            (tangents,))

    with backprop.GradientTape(persistent=True) as g:
      g.watch(primals)
      with backprop.GradientTape() as gg:
        gg.watch(primals)
        out = fun(primals)
      grad = array_ops.unstack(gg.gradient(out, primals))
    hessian = []
    for i in range(3):
      hessian.append(g.gradient(grad[i], primals))
    hessian = array_ops.stack(hessian, axis=0)
    backback_hvp = math_ops.tensordot(hessian, tangents, axes=1)

    self.assertAllClose(backback_hvp, forwardback_hvp_eager)
    self.assertAllClose(backback_hvp, forwardback_hvp_function)

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testShouldRecordAndStopRecord(self):
    c = constant_op.constant(1.)
    c_tangent = constant_op.constant(2.)
    with forwardprop.ForwardAccumulator(c, c_tangent) as acc:
      with backprop.GradientTape() as tape:
        self.assertFalse(tape_lib.should_record_backprop([c]))
        self.assertEqual(1, pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
        tape.watch(c)
        self.assertEqual(2, pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
        self.assertTrue(tape_lib.should_record_backprop([c]))
        with tape_lib.stop_recording():
          self.assertEqual(0,
                           pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
          self.assertFalse(tape_lib.should_record_backprop([c]))
          d = c * 2.
        self.assertEqual(2, pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
        self.assertTrue(tape_lib.should_record_backprop([c]))
        self.assertFalse(tape_lib.should_record_backprop([d]))
        self.assertIsNone(acc.jvp(d))
      self.assertIsNone(tape.gradient(d, c))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testRecordingSelectively(self):
    c = constant_op.constant(1.)
    c_tangent = constant_op.constant(2.)
    with forwardprop.ForwardAccumulator(c, c_tangent) as acc:
      with backprop.GradientTape(persistent=True) as tape:
        tape.watch(c)
        with tape_lib.stop_recording():
          two = constant_op.constant(2.)
          d = c * two
          three = constant_op.constant(3.)
          e = c * three
        self.assertIsNone(acc.jvp(d))
        self.assertIsNone(acc.jvp(e))
        self.assertIsNone(tape.gradient(d, c))
        self.assertIsNone(tape.gradient(e, c))
        tape_lib.record_operation_forwardprop_only(
            "CustomForwardMul", [d], [c, two], lambda dd: (two * dd, c * dd),
            None)
        tape_lib.record_operation_backprop_only("CustomBackwardMul", [e],
                                                [c, three], lambda de:
                                                (three * de, c * de))
        self.assertAllClose(4., acc.jvp(d))
        self.assertIsNone(acc.jvp(e))
        self.assertIsNone(tape.gradient(d, c))
        self.assertAllClose(3., tape.gradient(e, c))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testOpWithNoTrainableOutputs(self):
    v = variables.Variable(1.)
    with forwardprop.ForwardAccumulator(v, 11.):
      v.assign_sub(0.5)
      self.assertAllClose(0.5, self.evaluate(v))

  # TODO(b/141025187): Add a no_new_pyobjects decorator.
  def testVariableReadInFunction(self):
    v = variables.Variable(1.)
    with forwardprop.ForwardAccumulator(v, 11.) as acc:

      @def_function.function
      def f():
        return v.read_value(), 2. * v.read_value()

      result = f()
      self.assertAllClose((1.0, 2.), result)
      self.assertAllClose((11., 22.), acc.jvp(result))

  @parameterized.named_parameters([("ForwardPropFirst", True),
                                   ("TapeFirst", False)])
  def testForwardOverBackwardMemoryEfficiency(self, forward_prop_first):
    # Watching depends on nesting, not creation order
    c = constant_op.constant(1.)
    if forward_prop_first:
      forward_accumulator = forwardprop.ForwardAccumulator(c, .1)
      gradient_tape = backprop.GradientTape()
    else:
      gradient_tape = backprop.GradientTape()
      forward_accumulator = forwardprop.ForwardAccumulator(c, .1)
    try:
      gc.disable()
      with gradient_tape as tape:
        # Adding and removing the tape multiple times in different nesting
        # patterns does not affect watch ordering.
        pass
      with forward_accumulator as acc:
        with gradient_tape as tape:
          tape.watch(c)
          d = math_ops.cos(c)
          self.assertFalse(tape_lib.should_record_backprop((acc.jvp(d),)))
          e = math_ops.cos(acc.jvp(d))
          math_ops.cos(e)
          weak_e = weakref.ref(e)
          del e
          self.assertIsNone(weak_e())
        self.assertIsNone(tape.gradient(acc.jvp(d), c))
    finally:
      gc.enable()

  @parameterized.named_parameters([("ForwardPropFirst", True),
                                   ("TapeFirst", False)])
  def testBackwardOverForward(self, forward_prop_first):
    c = constant_op.constant(1.)
    # Watching depends on nesting, not creation order
    if forward_prop_first:
      forward_accumulator = forwardprop.ForwardAccumulator(c, .1)
      gradient_tape = backprop.GradientTape()
    else:
      gradient_tape = backprop.GradientTape()
      forward_accumulator = forwardprop.ForwardAccumulator(c, .1)
    with gradient_tape as tape:
      with forward_accumulator as acc:
        tape.watch(c)
        d = math_ops.cos(c)
        self.assertTrue(tape_lib.should_record_backprop((acc.jvp(d),)))
      self.assertAllClose(-.1 * math_ops.cos(1.), tape.gradient(acc.jvp(d), c))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testRecordingWithJVPIndices(self):
    c = constant_op.constant(1.)
    with forwardprop.ForwardAccumulator(c, 10.) as acc:
      packed_input_tangents = forwardprop_util.pack_tangents([c]).tangents
      self.assertAllClose([10.], packed_input_tangents)
      d = constant_op.constant(2.)
      d_tangent = constant_op.constant(3.)
      tape_lib.record_operation_forwardprop_only("FunctionWithInlineJVPs",
                                                 [d] + [d_tangent],
                                                 [c] + packed_input_tangents,
                                                 None, (((0, 1),),))
      self.assertAllClose(3., acc.jvp(d))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testSpecialForwardFunctionUsed(self):
    c = constant_op.constant(1.)
    d = constant_op.constant(2.)
    e = constant_op.constant(3.)
    with forwardprop.ForwardAccumulator(c, 10.) as acc:
      tape_lib.record_operation("ForwardIsSpecial", [d], [c], None,
                                lambda jvp: [-2. * jvp])
      self.assertAllClose(-20., acc.jvp(d))
      tape_lib.record_operation("ForwardIsSpecial2", [], [], None, lambda: [])
      tape_lib.record_operation("ForwardIsSpecial3", [e], [d], None,
                                lambda x: [x])
      self.assertAllClose(-20., acc.jvp(e))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testVariableWatched(self):
    v = variables.Variable([1., 2., 3.])
    with forwardprop.ForwardAccumulator(v, constant_op.constant([.1, -.2,
                                                                 .3])) as acc:
      self.assertAllClose([.1, -.2, .3], acc.jvp(v))
      x = v * 2.
      self.assertAllClose([.2, -.4, .6], acc.jvp(x))
      x2 = v + .1
      self.assertAllClose([.1, -.2, .3], acc.jvp(x2))

  def testUnconnectedGradients(self):
    x = constant_op.constant(-1.)
    with forwardprop.ForwardAccumulator(x, 0.1) as acc:
      self.assertAllClose(0.1, acc.jvp(x, unconnected_gradients="zero"))
      self.assertAllClose(0.1, acc.jvp(x, unconnected_gradients="none"))
      y = constant_op.constant(-2.)
      self.assertAllClose(0.0, acc.jvp(y, unconnected_gradients="zero"))
      self.assertIsNone(acc.jvp(y, unconnected_gradients="none"))

  # TODO(kkb): One weakref instance is created with warmup_iters=2,
  # investigate.
  @test_util.assert_no_new_pyobjects_executing_eagerly(warmup_iters=3)
  def testVariableWatchedFunction(self):

    class _Model(module.Module):

      def __init__(self):
        self._v = None

      @def_function.function
      def compute_jvps(self):
        if self._v is None:
          self._v = variables.Variable([1., 2., 3.])
        with forwardprop.ForwardAccumulator(self._v,
                                            constant_op.constant([.1, -.2,
                                                                  .3])) as acc:
          x = self._v * 2.
          x2 = self._v + .1
        return acc.jvp((self._v, x, x2))

    model = _Model()
    v_jvp, x_jvp, x2_jvp = model.compute_jvps()
    self.assertAllClose([.1, -.2, .3], v_jvp)
    self.assertAllClose([.2, -.4, .6], x_jvp)
    self.assertAllClose([.1, -.2, .3], x2_jvp)

  def testIndexSlicesGrad(self):
    x = constant_op.constant([1.])

    with forwardprop.ForwardAccumulator(x, constant_op.constant([3.])) as acc:
      y = array_ops.gather(x, 0)
    self.assertAllClose(3., acc.jvp(y))

  def testIndexSlicesGradInFunction(self):

    @def_function.function
    def f(a):
      return array_ops.gather(a, 0)

    x = constant_op.constant([1.])

    with forwardprop.ForwardAccumulator(x, constant_op.constant([3.])) as acc:
      y = f(x)
    self.assertAllClose(3., acc.jvp(y))

  # NOTE: assert_no_new_pyobjects_executing_eagerly fails flakily on this
  # test... could be something wrong with the test decorator, or some sort of
  # nondeterministic caching.
  def testMirroredVariableWatched(self):

    def _replicated(input_tangent):
      with forwardprop.ForwardAccumulator(v, input_tangent) as acc:
        self.assertAllClose([.1, -.2, .3], acc.jvp(v))
        x = v * 2.
        self.assertAllClose([.2, -.4, .6], acc.jvp(x))
        x2 = v + .1
        self.assertAllClose([.1, -.2, .3], acc.jvp(x2))

    strategy = mirrored_strategy.MirroredStrategy()
    with strategy.scope():
      v = variables.Variable([1., 2., 3.])
      strategy.run(_replicated, args=(constant_op.constant([.1, -.2, .3]),))

  # TODO(b/141025187): Add a no_new_pyobjects decorator.
  def testArgumentUnused(self):
    v = constant_op.constant(1.)
    with forwardprop.ForwardAccumulator(v, 11.) as acc:

      @def_function.function
      def _f(x):
        del x
        return constant_op.constant(1.)

      result = _f(v)
      self.assertAllClose(1.0, result)
      self.assertIsNone(acc.jvp(result))


@def_function.function
def _has_loop(iters, y):
  ret = 0.
  for i in math_ops.range(iters):
    ret += y * math_ops.cast(i, dtypes.float32)
  return ret


@def_function.function
def _has_cond(k, y):
  if k > 1:
    ret = 3. * y
  else:
    ret = 0.
  return ret


@def_function.function
def _fprop_while(iters, y):
  with forwardprop.ForwardAccumulator(y, 1.) as acc:
    ret = 0.
    for i in math_ops.range(iters):
      ret += y * math_ops.cast(i, dtypes.float32)
  return acc.jvp(ret)


@def_function.function
def _fprop_cond(k, y):
  with forwardprop.ForwardAccumulator(y, 1.) as acc:
    if k > 1:
      ret = 3. * y
    else:
      ret = 0.
  return acc.jvp(ret)


class ControlFlowTests(test.TestCase):

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testOfFunctionWhile(self):
    y = constant_op.constant(1.)
    with forwardprop.ForwardAccumulator(y, 1.) as acc:
      self.assertAllClose(10., acc.jvp(_has_loop(constant_op.constant(5), y)))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testOfFunctionCond(self):
    y = constant_op.constant(1.)
    with forwardprop.ForwardAccumulator(y, 1.) as acc:
      self.assertAllClose(3., acc.jvp(_has_cond(constant_op.constant(5), y)))
      self.assertAllClose(0., acc.jvp(_has_cond(constant_op.constant(0), y)))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testInFunctionWhile(self):
    self.assertAllClose(
        10., _fprop_while(constant_op.constant(5), constant_op.constant(1.)))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testInFunctionCond(self):
    self.assertAllClose(
        3., _fprop_cond(constant_op.constant(5), constant_op.constant(1.)))
    self.assertAllClose(
        0., _fprop_cond(constant_op.constant(0), constant_op.constant(1.)))


class HessianTests(test.TestCase, parameterized.TestCase):

  def testHessian1D(self):
    # Note: stolen from ops/gradients_test.py
    m = 4
    rng = np.random.RandomState([1, 2, 3])
    mat_value = rng.randn(m, m).astype("float32")
    x_value = rng.randn(m).astype("float32")
    hess_value = mat_value + mat_value.T
    mat = variables.Variable(mat_value)

    def _f(x):
      return math_ops.reduce_sum(x[:, None] * mat * x[None, :])

    hessian_eager, = _forward_over_back_hessian(
        _f, [constant_op.constant(x_value)],
        use_pfor=False,
        dtype=[dtypes.float32])
    self.assertAllClose(hess_value, hessian_eager)
    hessian_function, = def_function.function(_forward_over_back_hessian)(
        _f, [constant_op.constant(x_value)],
        use_pfor=False,
        dtype=[dtypes.float32])
    self.assertAllClose(hess_value, hessian_function)
    hessian_pfor, = def_function.function(_forward_over_back_hessian)(
        _f, [constant_op.constant(x_value)],
        use_pfor=True,
        dtype=[dtypes.float32])
    self.assertAllClose(hess_value, hessian_pfor)


class BatchTests(test.TestCase, parameterized.TestCase):

  @parameterized.parameters([(math_ops.sin, (2, 3), 5),
                             (math_ops.sin, (2, 3, 4), 10)])
  def testJVPBatchCorrectness(self, f, primal_shape, batch_size):
    primals = [random_ops.random_uniform(primal_shape)]
    tangent_batch = [random_ops.random_uniform([batch_size, *primal_shape])]
    self.assertAllClose(
        _jvp_batch(f, primals, tangent_batch)[1],
        _jvp_batch_matmul(f, primals, *tangent_batch))

  def testBatchCorrectness(self):
    x = constant_op.constant(2.0)
    y = constant_op.constant(5.0)
    tangents = (
        constant_op.constant([1., 0., 1.]),
        constant_op.constant([0., 1., 1.]),
    )
    with forwardprop.ForwardAccumulator._batch_accumulator((x, y),
                                                           tangents) as acc:
      z = x * y
    self.assertAllClose(acc.jvp(z), constant_op.constant([5.0, 2.0, 7.0]))

  @parameterized.named_parameters([("ForwardPropFirst", True),
                                   ("TapeFirst", False)])
  def testBatchBackwardOverForward(self, forward_prop_first):
    x = constant_op.constant(1.)
    tangents = random_ops.random_normal(shape=[10], seed=1)
    expected = [-t * math_ops.cos(1.) for t in tangents]
    if forward_prop_first:
      batch_acc = forwardprop.ForwardAccumulator._batch_accumulator(x, tangents)
      gradient_tape = backprop.GradientTape(persistent=True)
    else:
      gradient_tape = backprop.GradientTape(persistent=True)
      batch_acc = forwardprop.ForwardAccumulator._batch_accumulator(x, tangents)
    with gradient_tape as tape:
      with batch_acc as acc:
        tape.watch(x)
        y = math_ops.cos(x)
        self.assertTrue(tape_lib.should_record_backprop((acc.jvp(y),)))
        jvps = acc.jvp(y)
      d2y_dx2 = [tape.gradient(dy_dx, x) for dy_dx in jvps]
    self.assertAllClose(expected, d2y_dx2)


if __name__ == "__main__":
  # TODO(allenl): Also test with 1.x-style graph mode.
  ops.enable_eager_execution()
  test.main()