# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for forward-mode automatic differentiation."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import gc
import weakref

from absl.testing import parameterized
import numpy as np

from tensorflow.python import pywrap_tfe
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.eager import forwardprop
from tensorflow.python.eager import forwardprop_util
from tensorflow.python.eager import tape as tape_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.parallel_for import control_flow_ops
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.platform import test
from tensorflow.python.util import nest

_X11_35_DERIVATIVES = [
    1.1**3.5, 3.5 * 1.1**2.5, 3.5 * 2.5 * 1.1**1.5, 3.5 * 2.5 * 1.5 * 1.1**0.5
]


# TODO(allenl): Move this somewhere useful once forward gradients are stable.
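# The module-level helpers below (`_jvp`, `_jacfwd`, `_hvp`,
# `_vectorize_parameters`, ...) are reference implementations the tests use to
# cross-check `ForwardAccumulator`: for example, `_jvp(f, primals, tangents)`
# runs `f` under an accumulator and returns roughly
# `(f(*primals), J_f(primals) @ tangents)` without materializing the Jacobian.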
def _jvp(f, primals, tangents):
  """Compute the jacobian of `f` at `primals` multiplied by `tangents`."""
  with forwardprop.ForwardAccumulator(primals, tangents) as acc:
    primals_out = f(*primals)
  return primals_out, acc.jvp(
      primals_out, unconnected_gradients=UnconnectedGradients.ZERO)


def _jacfwd(f, primals):
  """Compute the jacobian of `f` at `primals` using forward-mode autodiff."""
  jac_flat = []
  flat_primals = nest.flatten(primals)
  tangent_mask = [array_ops.zeros_like(primal) for primal in flat_primals]
  for primal_index, primal in enumerate(flat_primals):
    primal_vector = array_ops.reshape(primal, [-1])
    primal_vector_length = array_ops.size(primal_vector)
    jac_columns = []
    for element_index in math_ops.range(primal_vector_length):
      mask = array_ops.one_hot(element_index, primal_vector_length)
      tangent_mask[primal_index] = array_ops.reshape(mask,
                                                     array_ops.shape(primal))
      jac_columns.append(
          nest.map_structure(
              functools.partial(array_ops.reshape, shape=[-1]),
              _jvp(f, primals, nest.pack_sequence_as(primals,
                                                     tangent_mask))[1]))
    jac_flat.append(array_ops.stack(jac_columns, axis=1))
    tangent_mask[primal_index] = array_ops.zeros_like(primal)
  return nest.pack_sequence_as(primals, jac_flat)


def _jvp_batch(f, primal, tangents):
  """Compute JVPs of `f` at `primal` for each tangent in the batch `tangents`."""
  tf_function = def_function.function(f)

  return control_flow_ops.vectorized_map(
      functools.partial(_jvp, tf_function, primal), tangents)


def _jvp_batch_matmul(f, primals, tangent_batch):
  """Compute the jacobian of `f` at `primals` multiplied by each tangent in `tangent_batch`."""
  jac_fwd = _jacfwd(f, primals)

  def jac_mul(tangent):
    flat_tangent = array_ops.reshape(tangent, shape=[-1])
    tangent_vector = array_ops.expand_dims(flat_tangent, 1)
    jvp_vector = math_ops.matmul(jac_fwd, tangent_vector)
    return array_ops.reshape(jvp_vector, tangent.shape)

  return control_flow_ops.vectorized_map(jac_mul, tangent_batch)


def _grad(f, argnums=0):
  """Return a function which computes the gradient of `f`."""

  def _f(*params):
    with backprop.GradientTape() as tape:
      tape.watch(params)
      primals_out = f(*params)
    return tape.gradient(
        primals_out,
        params[argnums],
        unconnected_gradients=UnconnectedGradients.ZERO)

  return _f


def _gradfwd(f, argnums=0, f_out_dtypes=dtypes.float32):
  """Return a function which computes the gradient of `f` in forward mode."""

  def _f(*params):

    def _single_jvp(param_mask):
      with forwardprop.ForwardAccumulator(
          primals=[params[argnums]], tangents=param_mask) as acc:
        primals_out = f(*params)
      return acc.jvp(primals_out)

    # Building up a function to run with pfor takes a bit too long since we're
    # only running it a handful of times.
    return _vectorize_parameters(
        _single_jvp, [params[argnums]], use_pfor=False, dtype=f_out_dtypes)

  return _f


def _hvp(f, primals, tangents):
  """Compute a forward-over-back Hessian-vector product."""
  with forwardprop.ForwardAccumulator(primals, tangents) as acc:
    with backprop.GradientTape() as tape:
      tape.watch(primals)
      f_out = f(*primals)
      f_out.shape.assert_is_compatible_with([])
    return acc.jvp(tape.gradient(f_out, primals))


def _vectorize_parameters(f, params, use_pfor, dtype):
  """Loop over `params`, providing a one-hot mask to `f` for each."""
  parameter_sizes = [array_ops.size(param) for param in params]
  total_size = math_ops.add_n(parameter_sizes)

  def _wrapper(index):
    full_onehot = array_ops.one_hot(index, total_size)
    split_onehot = array_ops.split(full_onehot, parameter_sizes)
    tangents = [
        array_ops.reshape(v, array_ops.shape(param))
        for param, v in zip(params, split_onehot)
    ]
    return f(tangents)

  if use_pfor:
    return control_flow_ops.vectorized_map(_wrapper,
                                           math_ops.range(total_size))

  return map_fn.map_fn(_wrapper, math_ops.range(total_size), dtype)


def _forward_over_back_hessian(f, params, use_pfor, dtype=None):
  """Computes the full Hessian matrix for the scalar-valued f(*params).

  Args:
    f: A function taking `params` and returning a scalar.
    params: A possibly nested structure of tensors.
    use_pfor: If true, uses `tf.vectorized_map` calls instead of looping.
    dtype: Required if `use_pfor=False`. A possibly nested structure of dtypes
      (e.g. `tf.float32`) matching the structure of `f`'s returns.

  Returns:
    A possibly nested structure of matrix slices corresponding to `params`. Each
    slice has shape [P, p_s] where `p_s` is the number of parameters (`tf.size`)
    in the corresponding element of `params` and `P` is the total number of
    parameters (`sum_s(p_s)`). The full matrix can be obtained by concatenating
    along the second axis.
  """
  return _vectorize_parameters(
      functools.partial(_hvp, f, params),
      params,
      use_pfor=use_pfor,
      dtype=dtype)


def _test_gradients(testcase,
                    f,
                    primals,
                    order,
                    delta=1e-3,
                    rtol=1e-2,
                    atol=1e-6):
  """Tests forward/backward jacobians of `f`'s [0, `order`)-order gradients."""
  if order < 1:
    raise ValueError(
        "`order` should be a positive integer, got '{}'.".format(order))
  if order > 1:
    _test_gradients(
        testcase=testcase,
        f=_grad(f),
        primals=primals,
        order=order - 1,
        delta=delta,
        rtol=rtol,
        atol=atol)
  sym_jac_back, num_jac = gradient_checker_v2.compute_gradient(
      f, primals, delta=delta)
  testcase.assertAllClose(num_jac, sym_jac_back, rtol=rtol, atol=atol)
  sym_jac_fwd = _jacfwd(f, primals)
  testcase.assertAllClose(num_jac, sym_jac_fwd, rtol=rtol, atol=atol)
  # And the symbolic computations should be much closer.
  testcase.assertAllClose(sym_jac_back, sym_jac_fwd)


class ForwardpropTest(test.TestCase, parameterized.TestCase):

  def testJVPFunction(self):
    add_outputs = (constant_op.constant(4.),)
    vp, = forwardprop._jvp_dispatch(
        op_name="Add",
        attr_tuple=(),
        inputs=(constant_op.constant(1.), constant_op.constant(3.)),
        outputs=add_outputs,
        tangents=(
            constant_op.constant(1.),
            constant_op.constant(5.),
        ))
    self.assertAllClose(1. + 5., self.evaluate(vp))

    mul_outputs = (constant_op.constant([20.]),)
    vp, = forwardprop._jvp_dispatch(
        op_name="Mul",
        attr_tuple=(),
        inputs=(constant_op.constant([4.]), constant_op.constant([5.])),
        outputs=mul_outputs,
        tangents=(
            constant_op.constant([2.]),
            constant_op.constant([3.]),
        ))
    self.assertAllClose([2. * 5. + 3. * 4.], self.evaluate(vp))

  def testJVPFunctionWithBatchOfTangents(self):
    add_outputs = (constant_op.constant(4.),)
    jvp_flat = forwardprop._jvp_dispatch(
        op_name="Add",
        attr_tuple=(),
        inputs=(constant_op.constant(1.), constant_op.constant(3.)),
        outputs=add_outputs,
        tangents=(
            constant_op.constant([1., 2., 3.]),
            constant_op.constant([4., 5., 6.]),
        ),
        use_batch=True)

    # Using evaluate and asserting with just a list works too
    # but the output is more explicit this way
    self.assertAllClose([constant_op.constant([1. + 4., 2. + 5., 3. + 6.])],
                        jvp_flat)

    mul_outputs = (constant_op.constant([20.]),)
    jvp_flat = forwardprop._jvp_dispatch(
        op_name="Mul",
        attr_tuple=(),
        inputs=(constant_op.constant([4.]), constant_op.constant([5.])),
        outputs=mul_outputs,
        tangents=(
            constant_op.constant([[1.], [0.], [1.]]),
            constant_op.constant([[0.], [1.], [1.]]),
        ),
        use_batch=True)
    self.assertAllClose([constant_op.constant([[5.], [4.], [5. + 4.]])],
                        jvp_flat)

  def testJVPFunctionRaisesError(self):
    sum_outputs = (constant_op.constant(6.),)

    with self.assertRaisesRegex(ValueError, r".*was expected to be of shape*"):
      forwardprop._jvp_dispatch(
          op_name="Add",
          attr_tuple=(),
          inputs=(constant_op.constant(2.), constant_op.constant(4.)),
          outputs=sum_outputs,
          tangents=(constant_op.constant([1., 2.]),
                    constant_op.constant([[1.], [2.]])),
          use_batch=True)

  def testNonDifferentiableOpWithInputTangent(self):
    x = constant_op.constant(1.)
    with forwardprop.ForwardAccumulator(x, 2.) as acc1:
      with forwardprop.ForwardAccumulator(x, 2.) as acc2:
        y = array_ops.zeros_like(x)
      self.assertIsNone(acc1.jvp(y))
      self.assertIsNone(acc2.jvp(y))

  def testRunFunctionsEagerly(self):
    try:
      original_setting = def_function.functions_run_eagerly()
      def_function.run_functions_eagerly(True)
      x = constant_op.constant(1.)
      with forwardprop.ForwardAccumulator(x, 2.) as acc:
        y = x * 3.
      self.assertAllClose(6., acc.jvp(y))
    finally:
      def_function.run_functions_eagerly(original_setting)

  def testJVPFunctionUsedByAccumulatorForOps(self):
    previous_fn = forwardprop._jvp_dispatch
    try:
      x = constant_op.constant(1.)
      with forwardprop.ForwardAccumulator(x, 2.) as acc:
        y = x + x
        pywrap_tfe.TFE_Py_RegisterJVPFunction(
            lambda *args, **kwargs: [constant_op.constant(-15.)])
        z = x + x
      self.assertAllClose(4., acc.jvp(y))
      self.assertAllClose(-15., acc.jvp(z))
    finally:
      pywrap_tfe.TFE_Py_RegisterJVPFunction(previous_fn)

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testFunctionCacheLimited(self):
    # Every time this test is executed, it will create a slightly larger Tensor
    # and push it through Add's gradient. Since we check for new pyobjects after
    # the warmup, retracing each time without cleaning up old traces fails the
    # test. It works because of experimental_relax_shapes.
    for _ in range(forwardprop._TRACE_COUNT_LIMIT):
      execution_count = getattr(self, "_execution_count", 0)
      self._execution_count = execution_count + 1
      x = array_ops.zeros([execution_count])
      with forwardprop.ForwardAccumulator(x, array_ops.ones_like(x)) as acc:
        y = x + x
      self.assertAllClose(2. * array_ops.ones_like(x), acc.jvp(y))

  def testVariableUnwatchedZero(self):
    v = variables.Variable([[1.]])
    x = constant_op.constant(1.)
    xt = constant_op.constant(2.)
    with forwardprop.ForwardAccumulator(x, xt) as acc:
      pass
    self.assertIsNone(acc.jvp(v))
    self.assertAllClose([[0.]], acc.jvp(v, unconnected_gradients="zero"))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testFunctionReturnsResource(self):
    v = variables.Variable([[1.]])
    x = constant_op.constant(1.)
    xt = constant_op.constant(2.)

    @def_function.function
    def f(a):
      return a, v.handle

    with forwardprop.ForwardAccumulator(x, xt) as acc:
      y, _ = f(x)
    self.assertAllClose(2., acc.jvp(y))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testMultipleWatchesAdd(self):
    x = constant_op.constant(-2.)
    with self.assertRaisesRegex(ValueError, "multiple times"):
      with forwardprop.ForwardAccumulator([x, x], [1., 2.]):
        pass
    with forwardprop.ForwardAccumulator([x], [3.]) as acc:
      self.assertAllClose(3., acc.jvp(x))
      acc._watch(x, constant_op.constant(10.))
      self.assertAllClose(13., acc.jvp(x))
      acc._watch(x, constant_op.constant(11.))
      self.assertAllClose(24., acc.jvp(x))
      y = constant_op.constant(3.) * x
    self.assertAllClose(24., acc.jvp(x))
    self.assertAllClose(24. * 3., acc.jvp(y))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testReenter(self):
    x = constant_op.constant(-2.)
    with forwardprop.ForwardAccumulator(x, 1.5) as acc:
      self.assertAllClose(1.5, acc.jvp(x))
      y = 4. * x
      self.assertAllClose(6., acc.jvp(y))
      with self.assertRaisesRegex(ValueError, "already recording"):
        with acc:
          pass
    z = 4. * x
    self.assertIsNone(acc.jvp(z))
    with acc:
      yy = y * y
    self.assertAllClose(6. * -8. * 2., acc.jvp(yy))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testDeadTensorsJVPCleared(self):
    x = array_ops.ones([100])
    x_weak = weakref.ref(x)
    grad_tensor = constant_op.constant(array_ops.zeros([100]))
    grad_tensor_weak = weakref.ref(grad_tensor)
    with forwardprop.ForwardAccumulator(x, grad_tensor) as acc:
      derived_tensor = constant_op.constant(2.) * x
      del grad_tensor
      self.assertAllClose(array_ops.zeros([100]), acc.jvp(x))
      del x
      self.assertIsNone(x_weak())
      self.assertIsNone(grad_tensor_weak())
      derived_tensor_weak = weakref.ref(derived_tensor)
      derived_tensor_grad = acc.jvp(derived_tensor)
      derived_tensor_grad_weak = weakref.ref(derived_tensor_grad)
      del derived_tensor
      del derived_tensor_grad
      self.assertIsNone(derived_tensor_weak())
      self.assertIsNone(derived_tensor_grad_weak())

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testJVPManual(self):
    primal, tangent = _jvp(math_ops.sin, (constant_op.constant(0.1),),
                           (constant_op.constant(0.2),))
    self.assertAllClose(math_ops.sin(0.1), primal)
    self.assertAllClose(math_ops.cos(0.1) * 0.2, tangent)

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testNumericHigherOrder(self):

    def f(x):
      pointwise = math_ops.sin(x) * math_ops.tan(x)
      return math_ops.reduce_prod(
          pointwise + math_ops.reduce_sum(pointwise), axis=1)

    _test_gradients(
        self, f, [constant_op.constant([[2.0, 3.0], [1.0, 4.0]])], order=3)

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testCustomGradient(self):

    @custom_gradient.custom_gradient
    def f(x):

      def grad(dy):
        return dy * math_ops.cos(x)

      return np.sin(x.numpy()), grad

    _test_gradients(self, f, [constant_op.constant([1., 2.])], order=3)

  # TODO(allenl): investigate why assert_no_new_pyobjects_executing_eagerly
  # fails around this test?
  def testExceptionCustomGradientRecomputeGradForward(self):

    @custom_gradient.recompute_grad
    def f(x):
      return math_ops.reduce_prod(math_ops.tanh(x)**2)

    with self.assertRaisesRegex(NotImplementedError,
                                "recompute_grad tried to transpose"):
      primals = [constant_op.constant([1.])]
      sym_jac_fwd = _jacfwd(f, primals)

  def testExceptionInCustomGradientNotSwallowed(self):

    @custom_gradient.custom_gradient
    def f(unused_x):

      def grad(unused_dy):
        raise ValueError("test_error_string")

      return 1., grad

    c = constant_op.constant(1.)
    d = constant_op.constant(2.)
    with forwardprop.ForwardAccumulator(c, d):
      with self.assertRaisesRegex(ValueError, "test_error_string"):
        f(c)

  @parameterized.named_parameters([("EluM5", -0.5, nn_ops.elu),
                                   ("EluP5", [0.5], nn_ops.elu),
                                   ("SwishP5", 0.5, nn_impl.swish),
                                   ("SwishM5", [-0.5], nn_impl.swish)])
  def testElementwiseNNOps(self, value, op_fn):
    _test_gradients(self, op_fn, [constant_op.constant(value)], order=3)

  def testFusedBatchNormGradsInference(self):

    x_shape = [4, 10, 10, 2]
    increment = 3. / math_ops.reduce_prod(
        constant_op.constant(x_shape, dtype=dtypes.float32))
    x = array_ops.reshape(math_ops.range(-2., 1., increment), x_shape)
    scale = constant_op.constant([1., 1.1])
    offset = constant_op.constant([-0.5, -0.6])
    mean = constant_op.constant([-1.3, 1.4])
    variance = constant_op.constant([0.7, 0.9])
    epsilon = 0.001

    def _bn_fused(x_arg, scale_arg, offset_arg):
      return nn_impl.fused_batch_norm(
          x_arg,
          scale_arg,
          offset_arg,
          mean,
          variance,
          epsilon=epsilon,
          is_training=False)[0]

    _test_gradients(self, _bn_fused, [x, scale, offset], order=2, atol=1e-2)

  def testPushPopAccumulatorState(self):
    # Note that this example is somewhat contrived. push_forwardprop_state is
    # probably only useful in practice for building functions that compute jvps
    # alongside their usual outputs.
    c = constant_op.constant(1.)
    d = constant_op.constant(2.)
    with forwardprop.ForwardAccumulator(c, d) as acc:

      @custom_gradient.custom_gradient
      def f(x):
        y = math_ops.sin(x.numpy())

        def grad(dy):
          with forwardprop_util.push_forwardprop_state():
            x_copy = constant_op.constant(x.numpy())
            acc._watch(x_copy, dy)
            y_copy = math_ops.sin(x_copy)
          return dy * acc.jvp(y_copy)

        return y, grad

      output = f(c)
    self.assertAllClose(d * math_ops.cos(c), acc.jvp(output))

  @parameterized.named_parameters([
      ("Order{}".format(order), order, expected)
      for order, expected in enumerate(_X11_35_DERIVATIVES)
  ])
  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testHigherOrderPureForward(self, order, expected):

    def _forwardgrad(f):

      def _compute_forwardgrad(primal):
        tangent = constant_op.constant(1.)
        with forwardprop.ForwardAccumulator(primal, tangent) as acc:
          primal_out = f(primal)
        return acc.jvp(primal_out)

      return _compute_forwardgrad

    def _forward(x):
      return x**3.5

    f = _forward
    primal = constant_op.constant(1.1)
    for _ in range(order):
      f = _forwardgrad(f)
    self.assertAllClose(expected, f(primal))

  @parameterized.named_parameters([("Function", def_function.function),
                                   ("NoFunction", lambda f: f)])
  def testGradPureForward(self, decorator):

    @decorator
    def f(x):
      return x**3.5

    primal = constant_op.constant(1.1)
    with forwardprop.ForwardAccumulator(primal,
                                        constant_op.constant(1.)) as outer_acc:
      with forwardprop.ForwardAccumulator(primal,
                                          constant_op.constant(1.)) as acc:
        primal_out = f(primal)
    inner_jvp = acc.jvp(primal_out)
    outer_jvp = outer_acc.jvp(inner_jvp)
    self.assertAllClose(1.1**3.5, primal_out)
    self.assertAllClose(3.5 * 1.1**2.5, inner_jvp)
    self.assertAllClose(3.5 * 2.5 * 1.1**1.5, outer_jvp)
    self.assertIsNone(acc.jvp(outer_acc.jvp(primal_out)))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testJVPPacking(self):
    two = constant_op.constant(2.)
    primal_in = constant_op.constant(1.)
    inner_jvp = constant_op.constant(3.)
    with forwardprop.ForwardAccumulator(
        [primal_in, inner_jvp],
        [constant_op.constant(2.),
         constant_op.constant(4.)]) as outer_acc:
      with forwardprop.ForwardAccumulator(primal_in, inner_jvp) as inner_acc:
        packed_input_indices, packed_input_tangents = (
            forwardprop_util.pack_tangents([primal_in]))
        self.assertAllClose([3., 2., 4.], packed_input_tangents)
        expected_indices = (
            # inner_acc watches primal_in
            ((0, 1),),
            # outer_acc watches primal_in and inner_jvp
            ((0, 2), (1, 3)))
        self.assertAllEqual(expected_indices, packed_input_indices)
        primal_out = primal_in * two
        self.assertAllClose(6., inner_acc.jvp(primal_out))
        self.assertAllClose(4., outer_acc.jvp(primal_out))
        self.assertAllClose(8., outer_acc.jvp(inner_acc.jvp(primal_out)))
        packed_output_indices, packed_output_tangents = (
            forwardprop_util.pack_tangents([primal_out]))
        self.assertAllClose([6., 4., 8.], packed_output_tangents)
        self.assertAllEqual(expected_indices, packed_output_indices)

  def testFunctionGradInFunctionPureForward(self):

    @def_function.function
    def take_gradients():

      @def_function.function
      def f(x):
        return x**3.5

      primal = constant_op.constant(1.1)
      with forwardprop.ForwardAccumulator(
          primal, constant_op.constant(1.)) as outer_acc:
        with forwardprop.ForwardAccumulator(primal,
                                            constant_op.constant(1.)) as acc:
          primal_out = f(primal)
      inner_jvp = acc.jvp(primal_out)
      outer_jvp = outer_acc.jvp(inner_jvp)
      self.assertIsNone(acc.jvp(outer_acc.jvp(primal_out)))
      return primal_out, inner_jvp, outer_jvp

    primal_out, inner_jvp, outer_jvp = take_gradients()
    self.assertAllClose(1.1**3.5, primal_out)
    self.assertAllClose(3.5 * 1.1**2.5, inner_jvp)
    self.assertAllClose(3.5 * 2.5 * 1.1**1.5, outer_jvp)

  def testFunctionGrad(self):

    @def_function.function
    def f(x):
      return math_ops.reduce_prod(math_ops.tanh(x)**2)

    _test_gradients(self, f, [constant_op.constant([1., 2.])], order=3)

  def testReusingJVP(self):
    m1 = random_ops.random_uniform((256, 2096))
    m2 = array_ops.identity(m1)
    tangent1 = random_ops.random_uniform((256, 2096))
    tangent2 = random_ops.random_uniform((256, 2096))
    matmul = def_function.function(math_ops.matmul)

    with forwardprop.ForwardAccumulator(
        primals=[m1, m2], tangents=[tangent1, tangent2]) as acc:
      result1 = matmul(m1, m1, transpose_b=True)
      result2 = matmul(m2, m2, transpose_b=True)

    def _expected(mat, tangent):
      return (math_ops.matmul(tangent, mat, transpose_b=True) +
              math_ops.matmul(mat, tangent, transpose_b=True))

    self.assertAllClose(result1, result2)
    self.assertAllClose(_expected(m1, tangent1), acc.jvp(result1))
    self.assertAllClose(_expected(m2, tangent2), acc.jvp(result2))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testHVPMemory(self):

    def fun(x):
      return math_ops.reduce_prod(math_ops.tanh(x)**2)

    primals = constant_op.constant([1., 2., 3.])
    tangents = constant_op.constant([3., 4., 5.])
    _hvp(fun, (primals,), (tangents,))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testHVPCorrectness(self):

    def fun(x):
      return math_ops.reduce_prod(math_ops.tanh(x)**2)

    primals = constant_op.constant([1., 2., 3.])
    tangents = constant_op.constant([3., 4., 5.])
    forwardback_hvp_eager, = _hvp(fun, (primals,), (tangents,))
    forwardback_hvp_function, = def_function.function(_hvp)(fun, (primals,),
                                                            (tangents,))

    with backprop.GradientTape(persistent=True) as g:
      g.watch(primals)
      with backprop.GradientTape() as gg:
        gg.watch(primals)
        out = fun(primals)
      grad = array_ops.unstack(gg.gradient(out, primals))
      hessian = []
      for i in range(3):
        hessian.append(g.gradient(grad[i], primals))
      hessian = array_ops.stack(hessian, axis=0)
    backback_hvp = math_ops.tensordot(hessian, tangents, axes=1)

    self.assertAllClose(backback_hvp, forwardback_hvp_eager)
    self.assertAllClose(backback_hvp, forwardback_hvp_function)

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testShouldRecordAndStopRecord(self):
    c = constant_op.constant(1.)
    c_tangent = constant_op.constant(2.)
    with forwardprop.ForwardAccumulator(c, c_tangent) as acc:
      with backprop.GradientTape() as tape:
        self.assertFalse(tape_lib.should_record_backprop([c]))
        self.assertEqual(1, pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
        tape.watch(c)
        self.assertEqual(2, pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
        self.assertTrue(tape_lib.should_record_backprop([c]))
        with tape_lib.stop_recording():
          self.assertEqual(0,
                           pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
          self.assertFalse(tape_lib.should_record_backprop([c]))
          d = c * 2.
        self.assertEqual(2, pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
        self.assertTrue(tape_lib.should_record_backprop([c]))
        self.assertFalse(tape_lib.should_record_backprop([d]))
        self.assertIsNone(acc.jvp(d))
      self.assertIsNone(tape.gradient(d, c))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testRecordingSelectively(self):
    c = constant_op.constant(1.)
    c_tangent = constant_op.constant(2.)
    with forwardprop.ForwardAccumulator(c, c_tangent) as acc:
      with backprop.GradientTape(persistent=True) as tape:
        tape.watch(c)
        with tape_lib.stop_recording():
          two = constant_op.constant(2.)
          d = c * two
          three = constant_op.constant(3.)
          e = c * three
        self.assertIsNone(acc.jvp(d))
        self.assertIsNone(acc.jvp(e))
        self.assertIsNone(tape.gradient(d, c))
        self.assertIsNone(tape.gradient(e, c))
        tape_lib.record_operation_forwardprop_only(
            "CustomForwardMul", [d], [c, two], lambda dd: (two * dd, c * dd),
            None)
        tape_lib.record_operation_backprop_only("CustomBackwardMul", [e],
                                                [c, three], lambda de:
                                                (three * de, c * de))
        self.assertAllClose(4., acc.jvp(d))
        self.assertIsNone(acc.jvp(e))
        self.assertIsNone(tape.gradient(d, c))
        self.assertAllClose(3., tape.gradient(e, c))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testOpWithNoTrainableOutputs(self):
    v = variables.Variable(1.)
    with forwardprop.ForwardAccumulator(v, 11.):
      v.assign_sub(0.5)
      self.assertAllClose(0.5, self.evaluate(v))

  # TODO(b/141025187): Add a no_new_pyobjects decorator.
  def testVariableReadInFunction(self):
    v = variables.Variable(1.)
    with forwardprop.ForwardAccumulator(v, 11.) as acc:

      @def_function.function
      def f():
        return v.read_value(), 2. * v.read_value()

      result = f()
      self.assertAllClose((1.0, 2.), result)
      self.assertAllClose((11., 22.), acc.jvp(result))

  @parameterized.named_parameters([("ForwardPropFirst", True),
                                   ("TapeFirst", False)])
  def testForwardOverBackwardMemoryEfficiency(self, forward_prop_first):
    # Watching depends on nesting, not creation order
    c = constant_op.constant(1.)
    if forward_prop_first:
      forward_accumulator = forwardprop.ForwardAccumulator(c, .1)
      gradient_tape = backprop.GradientTape()
    else:
      gradient_tape = backprop.GradientTape()
      forward_accumulator = forwardprop.ForwardAccumulator(c, .1)
    try:
      gc.disable()
      with gradient_tape as tape:
        # Adding and removing the tape multiple times in different nesting
        # patterns does not affect watch ordering.
        pass
      with forward_accumulator as acc:
        with gradient_tape as tape:
          tape.watch(c)
          d = math_ops.cos(c)
          self.assertFalse(tape_lib.should_record_backprop((acc.jvp(d),)))
          e = math_ops.cos(acc.jvp(d))
          math_ops.cos(e)
          weak_e = weakref.ref(e)
          del e
          self.assertIsNone(weak_e())
          self.assertIsNone(tape.gradient(acc.jvp(d), c))
    finally:
      gc.enable()

  @parameterized.named_parameters([("ForwardPropFirst", True),
                                   ("TapeFirst", False)])
  def testBackwardOverForward(self, forward_prop_first):
    c = constant_op.constant(1.)
    # Watching depends on nesting, not creation order
    if forward_prop_first:
      forward_accumulator = forwardprop.ForwardAccumulator(c, .1)
      gradient_tape = backprop.GradientTape()
    else:
      gradient_tape = backprop.GradientTape()
      forward_accumulator = forwardprop.ForwardAccumulator(c, .1)
    with gradient_tape as tape:
      with forward_accumulator as acc:
        tape.watch(c)
        d = math_ops.cos(c)
        self.assertTrue(tape_lib.should_record_backprop((acc.jvp(d),)))
    self.assertAllClose(-.1 * math_ops.cos(1.), tape.gradient(acc.jvp(d), c))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testRecordingWithJVPIndices(self):
    c = constant_op.constant(1.)
    with forwardprop.ForwardAccumulator(c, 10.) as acc:
      packed_input_tangents = forwardprop_util.pack_tangents([c]).tangents
      self.assertAllClose([10.], packed_input_tangents)
      d = constant_op.constant(2.)
      d_tangent = constant_op.constant(3.)
      tape_lib.record_operation_forwardprop_only("FunctionWithInlineJVPs",
                                                 [d] + [d_tangent],
                                                 [c] + packed_input_tangents,
                                                 None, (((0, 1),),))
      self.assertAllClose(3., acc.jvp(d))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testSpecialForwardFunctionUsed(self):
    c = constant_op.constant(1.)
    d = constant_op.constant(2.)
    e = constant_op.constant(3.)
    with forwardprop.ForwardAccumulator(c, 10.) as acc:
      tape_lib.record_operation("ForwardIsSpecial", [d], [c], None,
                                lambda jvp: [-2. * jvp])
      self.assertAllClose(-20., acc.jvp(d))
      tape_lib.record_operation("ForwardIsSpecial2", [], [], None, lambda: [])
      tape_lib.record_operation("ForwardIsSpecial3", [e], [d], None,
                                lambda x: [x])
      self.assertAllClose(-20., acc.jvp(e))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testVariableWatched(self):
    v = variables.Variable([1., 2., 3.])
    with forwardprop.ForwardAccumulator(v, constant_op.constant([.1, -.2,
                                                                 .3])) as acc:
      self.assertAllClose([.1, -.2, .3], acc.jvp(v))
      x = v * 2.
      self.assertAllClose([.2, -.4, .6], acc.jvp(x))
      x2 = v + .1
      self.assertAllClose([.1, -.2, .3], acc.jvp(x2))

  def testUnconnectedGradients(self):
    x = constant_op.constant(-1.)
    with forwardprop.ForwardAccumulator(x, 0.1) as acc:
      self.assertAllClose(0.1, acc.jvp(x, unconnected_gradients="zero"))
      self.assertAllClose(0.1, acc.jvp(x, unconnected_gradients="none"))
      y = constant_op.constant(-2.)
      self.assertAllClose(0.0, acc.jvp(y, unconnected_gradients="zero"))
      self.assertIsNone(acc.jvp(y, unconnected_gradients="none"))

  # TODO(kkb): One weakref instance is created with warmup_iters=2,
  # investigate.
  @test_util.assert_no_new_pyobjects_executing_eagerly(warmup_iters=3)
  def testVariableWatchedFunction(self):

    class _Model(module.Module):

      def __init__(self):
        self._v = None

      @def_function.function
      def compute_jvps(self):
        if self._v is None:
          self._v = variables.Variable([1., 2., 3.])
        with forwardprop.ForwardAccumulator(self._v,
                                            constant_op.constant([.1, -.2,
                                                                  .3])) as acc:
          x = self._v * 2.
          x2 = self._v + .1
        return acc.jvp((self._v, x, x2))

    model = _Model()
    v_jvp, x_jvp, x2_jvp = model.compute_jvps()
    self.assertAllClose([.1, -.2, .3], v_jvp)
    self.assertAllClose([.2, -.4, .6], x_jvp)
    self.assertAllClose([.1, -.2, .3], x2_jvp)

  def testIndexSlicesGrad(self):
    x = constant_op.constant([1.])

    with forwardprop.ForwardAccumulator(x, constant_op.constant([3.])) as acc:
      y = array_ops.gather(x, 0)
    self.assertAllClose(3., acc.jvp(y))

  def testIndexSlicesGradInFunction(self):

    @def_function.function
    def f(a):
      return array_ops.gather(a, 0)

    x = constant_op.constant([1.])

    with forwardprop.ForwardAccumulator(x, constant_op.constant([3.])) as acc:
      y = f(x)
    self.assertAllClose(3., acc.jvp(y))

  # NOTE: assert_no_new_pyobjects_executing_eagerly fails flakily on this
  # test... could be something wrong with the test decorator, or some sort of
  # nondeterministic caching.
  def testMirroredVariableWatched(self):

    def _replicated(input_tangent):
      with forwardprop.ForwardAccumulator(v, input_tangent) as acc:
        self.assertAllClose([.1, -.2, .3], acc.jvp(v))
        x = v * 2.
        self.assertAllClose([.2, -.4, .6], acc.jvp(x))
        x2 = v + .1
        self.assertAllClose([.1, -.2, .3], acc.jvp(x2))

    strategy = mirrored_strategy.MirroredStrategy()
    with strategy.scope():
      v = variables.Variable([1., 2., 3.])
      strategy.run(_replicated, args=(constant_op.constant([.1, -.2, .3]),))

  # TODO(b/141025187): Add a no_new_pyobjects decorator.
  def testArgumentUnused(self):
    v = constant_op.constant(1.)
    with forwardprop.ForwardAccumulator(v, 11.) as acc:

      @def_function.function
      def _f(x):
        del x
        return constant_op.constant(1.)

      result = _f(v)
      self.assertAllClose(1.0, result)
      self.assertIsNone(acc.jvp(result))


@def_function.function
def _has_loop(iters, y):
  ret = 0.
  for i in math_ops.range(iters):
    ret += y * math_ops.cast(i, dtypes.float32)
  return ret


@def_function.function
def _has_cond(k, y):
  if k > 1:
    ret = 3. * y
  else:
    ret = 0.
  return ret


@def_function.function
def _fprop_while(iters, y):
  with forwardprop.ForwardAccumulator(y, 1.) as acc:
    ret = 0.
    for i in math_ops.range(iters):
      ret += y * math_ops.cast(i, dtypes.float32)
  return acc.jvp(ret)


@def_function.function
def _fprop_cond(k, y):
  with forwardprop.ForwardAccumulator(y, 1.) as acc:
    if k > 1:
      ret = 3. * y
    else:
      ret = 0.
  return acc.jvp(ret)


class ControlFlowTests(test.TestCase):

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testOfFunctionWhile(self):
    y = constant_op.constant(1.)
    with forwardprop.ForwardAccumulator(y, 1.) as acc:
      self.assertAllClose(10., acc.jvp(_has_loop(constant_op.constant(5), y)))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testOfFunctionCond(self):
    y = constant_op.constant(1.)
    with forwardprop.ForwardAccumulator(y, 1.) as acc:
      self.assertAllClose(3., acc.jvp(_has_cond(constant_op.constant(5), y)))
      self.assertAllClose(0., acc.jvp(_has_cond(constant_op.constant(0), y)))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testInFunctionWhile(self):
    self.assertAllClose(
        10., _fprop_while(constant_op.constant(5), constant_op.constant(1.)))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testInFunctionCond(self):
    self.assertAllClose(
        3., _fprop_cond(constant_op.constant(5), constant_op.constant(1.)))
    self.assertAllClose(
        0., _fprop_cond(constant_op.constant(0), constant_op.constant(1.)))


class HessianTests(test.TestCase, parameterized.TestCase):

  def testHessian1D(self):
    # Note: stolen from ops/gradients_test.py
    m = 4
    rng = np.random.RandomState([1, 2, 3])
    mat_value = rng.randn(m, m).astype("float32")
    x_value = rng.randn(m).astype("float32")
    hess_value = mat_value + mat_value.T
    mat = variables.Variable(mat_value)

    def _f(x):
      return math_ops.reduce_sum(x[:, None] * mat * x[None, :])

    hessian_eager, = _forward_over_back_hessian(
        _f, [constant_op.constant(x_value)],
        use_pfor=False,
        dtype=[dtypes.float32])
    self.assertAllClose(hess_value, hessian_eager)
    hessian_function, = def_function.function(_forward_over_back_hessian)(
        _f, [constant_op.constant(x_value)],
        use_pfor=False,
        dtype=[dtypes.float32])
    self.assertAllClose(hess_value, hessian_function)
    hessian_pfor, = def_function.function(_forward_over_back_hessian)(
        _f, [constant_op.constant(x_value)],
        use_pfor=True,
        dtype=[dtypes.float32])
    self.assertAllClose(hess_value, hessian_pfor)


class BatchTests(test.TestCase, parameterized.TestCase):

  @parameterized.parameters([(math_ops.sin, (2, 3), 5),
                             (math_ops.sin, (2, 3, 4), 10)])
  def testJVPBatchCorrectness(self, f, primal_shape, batch_size):
    primals = [random_ops.random_uniform(primal_shape)]
    tangent_batch = [random_ops.random_uniform([batch_size, *primal_shape])]
    self.assertAllClose(
        _jvp_batch(f, primals, tangent_batch)[1],
        _jvp_batch_matmul(f, primals, *tangent_batch))

  def testBatchCorrectness(self):
    x = constant_op.constant(2.0)
    y = constant_op.constant(5.0)
    tangents = (
        constant_op.constant([1., 0., 1.]),
        constant_op.constant([0., 1., 1.]),
    )
    with forwardprop.ForwardAccumulator._batch_accumulator((x, y),
                                                           tangents) as acc:
      z = x * y
    self.assertAllClose(acc.jvp(z), constant_op.constant([5.0, 2.0, 7.0]))

  @parameterized.named_parameters([("ForwardPropFirst", True),
                                   ("TapeFirst", False)])
  def testBatchBackwardOverForward(self, forward_prop_first):
    x = constant_op.constant(1.)
    tangents = random_ops.random_normal(shape=[10], seed=1)
    expected = [-t * math_ops.cos(1.) for t in tangents]
    if forward_prop_first:
      batch_acc = forwardprop.ForwardAccumulator._batch_accumulator(x, tangents)
      gradient_tape = backprop.GradientTape(persistent=True)
    else:
      gradient_tape = backprop.GradientTape(persistent=True)
      batch_acc = forwardprop.ForwardAccumulator._batch_accumulator(x, tangents)
    with gradient_tape as tape:
      with batch_acc as acc:
        tape.watch(x)
        y = math_ops.cos(x)
        self.assertTrue(tape_lib.should_record_backprop((acc.jvp(y),)))
        jvps = acc.jvp(y)
      d2y_dx2 = [tape.gradient(dy_dx, x) for dy_dx in jvps]
    self.assertAllClose(expected, d2y_dx2)


if __name__ == "__main__":
  # TODO(allenl): Also test with 1.x-style graph mode.
  ops.enable_eager_execution()
  test.main()