# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import functools
import gc
import weakref

from absl.testing import parameterized
import numpy as np

from tensorflow.python import pywrap_tfe
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import forwardprop
from tensorflow.python.eager import forwardprop_util
from tensorflow.python.eager import tape as tape_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.parallel_for import control_flow_ops
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.platform import test
from tensorflow.python.util import nest

_X11_35_DERIVATIVES = [
    1.1**3.5, 3.5 * 1.1**2.5, 3.5 * 2.5 * 1.1**1.5, 3.5 * 2.5 * 1.5 * 1.1**0.5
]


# TODO(allenl): Move this somewhere useful once forward gradients are stable.
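# The helpers below exercise forward-mode autodiff end to end. For example
# (mirroring testJVPManual further down), `_jvp` evaluates a primal output and
# its Jacobian-vector product in one forward pass:
#   primal, tangent = _jvp(math_ops.sin, (constant_op.constant(0.1),),
#                          (constant_op.constant(0.2),))
# leaves primal == sin(0.1) and tangent == cos(0.1) * 0.2.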
def _jvp(f, primals, tangents):
  """Compute the jacobian of `f` at `primals` multiplied by `tangents`."""
  with forwardprop.ForwardAccumulator(primals, tangents) as acc:
    primals_out = f(*primals)
  return primals_out, acc.jvp(
      primals_out, unconnected_gradients=UnconnectedGradients.ZERO)


def _jacfwd(f, primals):
  """Compute the jacobian of `f` at `primals` using forward-mode autodiff."""
  jac_flat = []
  flat_primals = nest.flatten(primals)
  tangent_mask = [
      array_ops.zeros_like(primal, dtype=primal.dtype)
      for primal in flat_primals
  ]
  for primal_index, primal in enumerate(flat_primals):
    primal_vector = array_ops.reshape(primal, [-1])
    primal_vector_length = array_ops.size(primal_vector)
    jac_columns = []
    for element_index in math_ops.range(primal_vector_length):
      mask = array_ops.one_hot(
          element_index, primal_vector_length, dtype=primal.dtype)
      tangent_mask[primal_index] = array_ops.reshape(mask,
                                                     array_ops.shape(primal))
      jac_columns.append(
          nest.map_structure(
              functools.partial(array_ops.reshape, shape=[-1]),
              _jvp(f, primals, nest.pack_sequence_as(primals,
                                                     tangent_mask))[1]))
    jac_flat.append(array_ops.stack(jac_columns, axis=1))
    tangent_mask[primal_index] = array_ops.zeros_like(primal)
  return nest.pack_sequence_as(primals, jac_flat)


def _jvp_batch(f, primal, tangents):
  tf_function = def_function.function(f)

  return control_flow_ops.vectorized_map(
      functools.partial(_jvp, tf_function, primal), tangents)


def _jvp_batch_matmul(f, primals, tangent_batch):
  """Compute the jacobian of `f` at `primals` multiplied by `tangents`."""
  jac_fwd = _jacfwd(f, primals)

  def jac_mul(tangent):
    flat_tangent = array_ops.reshape(tangent, shape=[-1])
    tangent_vector = array_ops.expand_dims(flat_tangent, 1)
    jvp_vector = math_ops.matmul(jac_fwd, tangent_vector)
    return array_ops.reshape(jvp_vector, tangent.shape)

  return control_flow_ops.vectorized_map(jac_mul, tangent_batch)


def _grad(f, argnums=0):
  """Return a function which computes the gradient of `f`."""

  def _f(*params):
    with backprop.GradientTape() as tape:
      tape.watch(params)
      primals_out = f(*params)
    return tape.gradient(
        primals_out,
        params[argnums],
        unconnected_gradients=UnconnectedGradients.ZERO)

  return _f


def _gradfwd(f, argnums=0, f_out_dtypes=dtypes.float32):
  """Return a function which computes the gradient of `f` in forward mode."""

  def _f(*params):

    def _single_jvp(param_mask):
      with forwardprop.ForwardAccumulator(
          primals=[params[argnums]], tangents=param_mask) as acc:
        primals_out = f(*params)
      return acc.jvp(primals_out)

    # Building up a function to run with pfor takes a bit too long since we're
    # only running it a handful of times.
    return _vectorize_parameters(
        _single_jvp, [params[argnums]], use_pfor=False, dtype=f_out_dtypes)

  return _f


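# _hvp composes forward over reverse mode: the gradient computed under the
# GradientTape is itself watched by a ForwardAccumulator, so its JVP is a
# Hessian-vector product without ever materializing the full Hessian.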
def _hvp(f, primals, tangents):
  """Compute a forward-over-back Hessian-vector product."""
  with forwardprop.ForwardAccumulator(primals, tangents) as acc:
    with backprop.GradientTape() as tape:
      tape.watch(primals)
      f_out = f(*primals)
      f_out.shape.assert_is_compatible_with([])
    return acc.jvp(tape.gradient(f_out, primals))


def _vectorize_parameters(f, params, use_pfor, dtype):
  """Loop over `params`, providing a one-hot mask to `f` for each."""
  parameter_sizes = [array_ops.size(param) for param in params]
  total_size = math_ops.add_n(parameter_sizes)

  def _wrapper(index):
    full_onehot = array_ops.one_hot(index, total_size)
    split_onehot = array_ops.split(full_onehot, parameter_sizes)
    tangents = [
        array_ops.reshape(v, array_ops.shape(param))
        for param, v in zip(params, split_onehot)
    ]
    return f(tangents)

  if use_pfor:
    return control_flow_ops.vectorized_map(_wrapper, math_ops.range(total_size))

  return map_fn.map_fn(_wrapper, math_ops.range(total_size), dtype)


def _forward_over_back_hessian(f, params, use_pfor, dtype=None):
  """Computes the full Hessian matrix for the scalar-valued f(*params).

  Args:
    f: A function taking `params` and returning a scalar.
    params: A possibly nested structure of tensors.
    use_pfor: If true, uses `tf.vectorized_map` calls instead of looping.
    dtype: Required if `use_pfor=False`. A possibly nested structure of dtypes
      (e.g. `tf.float32`) matching the structure of `f`'s returns.

  Returns:
    A possibly nested structure of matrix slices corresponding to `params`. Each
    slice has shape [P, p_s] where `p_s` is the number of parameters (`tf.size`)
    in the corresponding element of `params` and `P` is the total number of
    parameters (`sum_s(p_s)`). The full matrix can be obtained by concatenating
    along the second axis.
  """
  return _vectorize_parameters(
      functools.partial(_hvp, f, params),
      params,
      use_pfor=use_pfor,
      dtype=dtype)


def _test_gradients(testcase,
                    f,
                    primals,
                    order,
                    delta=1e-3,
                    rtol=1e-2,
                    atol=1e-6,
                    srtol=1e-6,
                    satol=1e-6):
  """Tests forward/backward jacobians of `f`'s [0, `order`)-order gradients."""
  if order < 1:
    raise ValueError(
        "`order` should be a positive integer, got '{}'.".format(order))
  if order > 1:
    _test_gradients(
        testcase=testcase,
        f=_grad(f),
        primals=primals,
        order=order - 1,
        delta=delta,
        rtol=rtol,
        atol=atol,
        srtol=srtol,
        satol=satol)
  sym_jac_back, num_jac = gradient_checker_v2.compute_gradient(
      f, primals, delta=delta)
  testcase.assertAllClose(num_jac, sym_jac_back, rtol=rtol, atol=atol)
  sym_jac_fwd = _jacfwd(f, primals)
  testcase.assertAllClose(num_jac, sym_jac_fwd, rtol=rtol, atol=atol)
  # And the symbolic computations should be much closer.
  testcase.assertAllClose(sym_jac_back, sym_jac_fwd, rtol=srtol, atol=satol)


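# _test_gradients cross-checks three Jacobians at each requested order: numeric
# finite differences, symbolic reverse mode (gradient_checker_v2), and symbolic
# forward mode (_jacfwd). The two symbolic results are compared against each
# other with much tighter tolerances (srtol/satol) than the numeric checks.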
@test_util.with_eager_op_as_function
class ForwardpropTest(test.TestCase, parameterized.TestCase):

  def testJVPFunction(self):
    add_outputs = (constant_op.constant(4.),)
    vp, = forwardprop._jvp_dispatch(
        op_name="Add",
        attr_tuple=(),
        inputs=(constant_op.constant(1.), constant_op.constant(3.)),
        outputs=add_outputs,
        tangents=(
            constant_op.constant(1.),
            constant_op.constant(5.),
        ))
    self.assertAllClose(1. + 5., self.evaluate(vp))

    mul_outputs = (constant_op.constant([20.]),)
    vp, = forwardprop._jvp_dispatch(
        op_name="Mul",
        attr_tuple=(),
        inputs=(constant_op.constant([4.]), constant_op.constant([5.])),
        outputs=mul_outputs,
        tangents=(
            constant_op.constant([2.]),
            constant_op.constant([3.]),
        ))
    self.assertAllClose([2. * 5. + 3. * 4.], self.evaluate(vp))

  def testJVPFunctionWithBatchOfTangents(self):
    add_outputs = (constant_op.constant(4.),)
    jvp_flat = forwardprop._jvp_dispatch(
        op_name="Add",
        attr_tuple=(),
        inputs=(constant_op.constant(1.), constant_op.constant(3.)),
        outputs=add_outputs,
        tangents=(
            constant_op.constant([1., 2., 3.]),
            constant_op.constant([4., 5., 6.]),
        ),
        use_batch=True)

    # Using evaluate and asserting with just a list works too
    # but the output is more explicit this way
    self.assertAllClose([constant_op.constant([1. + 4., 2. + 5., 3. + 6.])],
                        jvp_flat)

    mul_outputs = (constant_op.constant([20.]),)
    jvp_flat = forwardprop._jvp_dispatch(
        op_name="Mul",
        attr_tuple=(),
        inputs=(constant_op.constant([4.]), constant_op.constant([5.])),
        outputs=mul_outputs,
        tangents=(
            constant_op.constant([[1.], [0.], [1.]]),
            constant_op.constant([[0.], [1.], [1.]]),
        ),
        use_batch=True)
    self.assertAllClose([constant_op.constant([[5.], [4.], [5. + 4.]])],
                        jvp_flat)

  def testJVPFunctionRaisesError(self):
    sum_outputs = (constant_op.constant(6.),)

    with self.assertRaisesRegex(ValueError, r".*was expected to be of shape*"):
      forwardprop._jvp_dispatch(
          op_name="Add",
          attr_tuple=(),
          inputs=(constant_op.constant(2.), constant_op.constant(4.)),
          outputs=sum_outputs,
          tangents=(constant_op.constant([1., 2.]),
                    constant_op.constant([[1.], [2.]])),
          use_batch=True)

  def testNonDifferentiableOpWithInputTangent(self):
    x = constant_op.constant(1.)
    with forwardprop.ForwardAccumulator(x, 2.) as acc1:
      with forwardprop.ForwardAccumulator(x, 2.) as acc2:
        y = array_ops.zeros_like(x)
      self.assertIsNone(acc1.jvp(y))
    self.assertIsNone(acc2.jvp(y))

  def testRunFunctionsEagerly(self):
    try:
      original_setting = def_function.functions_run_eagerly()
      def_function.run_functions_eagerly(True)
      x = constant_op.constant(1.)
      with forwardprop.ForwardAccumulator(x, 2.) as acc:
        y = x * 3.
      self.assertAllClose(6., acc.jvp(y))
    finally:
      def_function.run_functions_eagerly(original_setting)

  def testJVPFunctionUsedByAccumulatorForOps(self):
    previous_fn = forwardprop._jvp_dispatch
    try:
      x = constant_op.constant(1.)
      with forwardprop.ForwardAccumulator(x, 2.) as acc:
        y = x + x
        pywrap_tfe.TFE_Py_RegisterJVPFunction(
            lambda *args, **kwargs: [constant_op.constant(-15.)])
        z = x + x
      self.assertAllClose(4., acc.jvp(y))
      self.assertAllClose(-15., acc.jvp(z))
    finally:
      pywrap_tfe.TFE_Py_RegisterJVPFunction(previous_fn)

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testFunctionCacheLimited(self):
    # Every time this loop is executed, it will create a slightly larger Tensor
    # and push it through Add's gradient.
    # We run TRACE_COUNT_LIMIT x 2 so that it is tested with both
    # experimental_relax_shapes on and off.
    for execution_count in range(forwardprop._TRACE_COUNT_LIMIT * 2):
      x = array_ops.zeros([execution_count])
      with forwardprop.ForwardAccumulator(x, array_ops.ones_like(x)) as acc:
        y = x + x
      self.assertAllClose(2. * array_ops.ones_like(x), acc.jvp(y))

  def testVariableUnwatchedZero(self):
    v = variables.Variable([[1.]])
    x = constant_op.constant(1.)
    xt = constant_op.constant(2.)
    with forwardprop.ForwardAccumulator(x, xt) as acc:
      pass
    self.assertIsNone(acc.jvp(v))
    self.assertAllClose([[0.]], acc.jvp(v, unconnected_gradients="zero"))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testFunctionReturnsResource(self):
    v = variables.Variable([[1.]])
    x = constant_op.constant(1.)
    xt = constant_op.constant(2.)

    @def_function.function
    def f(a):
      return a, v.handle

    with forwardprop.ForwardAccumulator(x, xt) as acc:
      y, _ = f(x)
    self.assertAllClose(2., acc.jvp(y))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testMultipleWatchesAdd(self):
    x = constant_op.constant(-2.)
    with self.assertRaisesRegex(ValueError, "multiple times"):
      with forwardprop.ForwardAccumulator([x, x], [1., 2.]):
        pass
    with forwardprop.ForwardAccumulator([x], [3.]) as acc:
      self.assertAllClose(3., acc.jvp(x))
      acc._watch(x, constant_op.constant(10.))
      self.assertAllClose(13., acc.jvp(x))
      acc._watch(x, constant_op.constant(11.))
      self.assertAllClose(24., acc.jvp(x))
      y = constant_op.constant(3.) * x
      self.assertAllClose(24., acc.jvp(x))
      self.assertAllClose(24. * 3., acc.jvp(y))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testReenter(self):
    x = constant_op.constant(-2.)
    with forwardprop.ForwardAccumulator(x, 1.5) as acc:
      self.assertAllClose(1.5, acc.jvp(x))
      y = 4. * x
      self.assertAllClose(6., acc.jvp(y))
      with self.assertRaisesRegex(ValueError, "already recording"):
        with acc:
          pass
    z = 4. * x
    self.assertIsNone(acc.jvp(z))
    with acc:
      yy = y * y
    self.assertAllClose(6. * -8. * 2., acc.jvp(yy))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testDeadTensorsJVPCleared(self):
    x = array_ops.ones([100])
    x_weak = weakref.ref(x)
    grad_tensor = constant_op.constant(array_ops.zeros([100]))
    grad_tensor_weak = weakref.ref(grad_tensor)
    with forwardprop.ForwardAccumulator(x, grad_tensor) as acc:
      derived_tensor = constant_op.constant(2.) * x
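      # Dropping the Python references should be enough for the accumulator's
      # entries to be collected; the weakref checks below verify this.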
      del grad_tensor
      self.assertAllClose(array_ops.zeros([100]), acc.jvp(x))
      del x
      self.assertIsNone(x_weak())
      self.assertIsNone(grad_tensor_weak())
      derived_tensor_weak = weakref.ref(derived_tensor)
      derived_tensor_grad = acc.jvp(derived_tensor)
      derived_tensor_grad_weak = weakref.ref(derived_tensor_grad)
      del derived_tensor
      del derived_tensor_grad
      self.assertIsNone(derived_tensor_weak())
      self.assertIsNone(derived_tensor_grad_weak())

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testJVPManual(self):
    primal, tangent = _jvp(math_ops.sin, (constant_op.constant(0.1),),
                           (constant_op.constant(0.2),))
    self.assertAllClose(math_ops.sin(0.1), primal)
    self.assertAllClose(math_ops.cos(0.1) * 0.2, tangent)

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testNumericHigherOrder(self):

    def f(x):
      pointwise = math_ops.sin(x) * math_ops.tan(x)
      return math_ops.reduce_prod(
          pointwise + math_ops.reduce_sum(pointwise), axis=1)

    if (context.run_eager_op_as_function_enabled() and
        test_util.is_xla_enabled()):
      # Autoclustering kicks in when eager_op_as_function is enabled.
      # Under XLA the symbolic tolerances are less than under TF.
      # Ref: b/202559426
      _test_gradients(
          self,
          f, [constant_op.constant([[2.0, 3.0], [1.0, 4.0]])],
          order=3,
          srtol=1e-6,
          satol=1e-3)
    else:
      _test_gradients(
          self, f, [constant_op.constant([[2.0, 3.0], [1.0, 4.0]])], order=3)

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testNumericHigherOrderFloat64(self):

    def f(x):
      pointwise = math_ops.sin(x) * math_ops.tan(x)
      return math_ops.reduce_prod(
          pointwise + math_ops.reduce_sum(pointwise), axis=1)

    _test_gradients(
        self,
        f,
        [constant_op.constant([[2.0, 3.0], [1.0, 4.0]], dtype=dtypes.float64)],
        order=3)

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testCustomGradient(self):

    @custom_gradient.custom_gradient
    def f(x):

      def grad(dy):
        return dy * math_ops.cos(x)

      return np.sin(x.numpy()), grad

    _test_gradients(self, f, [constant_op.constant([1., 2.])], order=3)

  # TODO(allenl): investigate why assert_no_new_pyobjects_executing_eagerly
  # fails around this test?
  def testExceptionCustomGradientRecomputeGradForward(self):

    @custom_gradient.recompute_grad
    def f(x):
      return math_ops.reduce_prod(math_ops.tanh(x)**2)

    with self.assertRaisesRegex(NotImplementedError,
                                "recompute_grad tried to transpose"):
      primals = [constant_op.constant([1.])]
      sym_jac_fwd = _jacfwd(f, primals)

  def testExceptionInCustomGradientNotSwallowed(self):

    @custom_gradient.custom_gradient
    def f(unused_x):

      def grad(unused_dy):
        raise ValueError("test_error_string")

      return 1., grad

    c = constant_op.constant(1.)
    d = constant_op.constant(2.)
    with forwardprop.ForwardAccumulator(c, d):
      with self.assertRaisesRegex(ValueError, "test_error_string"):
        f(c)

  @parameterized.named_parameters([("EluM5", -0.5, nn_ops.elu),
                                   ("EluP5", [0.5], nn_ops.elu),
                                   ("SwishP5", 0.5, nn_impl.swish),
                                   ("SwishM5", [-0.5], nn_impl.swish)])
  def testElementwiseNNOps(self, value, op_fn):
    _test_gradients(self, op_fn, [constant_op.constant(value)], order=3)

  def testFusedBatchNormGradsInference(self):

    x_shape = [4, 10, 10, 2]
    increment = 3. / math_ops.reduce_prod(
        constant_op.constant(x_shape, dtype=dtypes.float32))
    x = array_ops.reshape(math_ops.range(-2., 1., increment), x_shape)
    scale = constant_op.constant([1., 1.1])
    offset = constant_op.constant([-0.5, -0.6])
    mean = constant_op.constant([-1.3, 1.4])
    variance = constant_op.constant([0.7, 0.9])
    epsilon = 0.001

    def _bn_fused(x_arg, scale_arg, offset_arg):
      return nn_impl.fused_batch_norm(
          x_arg,
          scale_arg,
          offset_arg,
          mean,
          variance,
          epsilon=epsilon,
          is_training=False)[0]

    _test_gradients(self, _bn_fused, [x, scale, offset], order=2, atol=1e-2)

  def testPushPopAccumulatorState(self):
    # Note that this example is somewhat contrived. push_forwardprop_state is
    # probably only useful in practice for building functions that compute jvps
    # alongside their usual outputs.
    c = constant_op.constant(1.)
    d = constant_op.constant(2.)
    with forwardprop.ForwardAccumulator(c, d) as acc:

      @custom_gradient.custom_gradient
      def f(x):
        y = math_ops.sin(x.numpy())

        def grad(dy):
          with forwardprop_util.push_forwardprop_state():
            x_copy = constant_op.constant(x.numpy())
            acc._watch(x_copy, dy)
            y_copy = math_ops.sin(x_copy)
          return dy * acc.jvp(y_copy)

        return y, grad

      output = f(c)
      self.assertAllClose(d * math_ops.cos(c), acc.jvp(output))

  @parameterized.named_parameters([
      ("Order{}".format(order), order, expected)
      for order, expected in enumerate(_X11_35_DERIVATIVES)
  ])
  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testHigherOrderPureForward(self, order, expected):

    def _forwardgrad(f):

      def _compute_forwardgrad(primal):
        tangent = constant_op.constant(1.)
        with forwardprop.ForwardAccumulator(primal, tangent) as acc:
          primal_out = f(primal)
        return acc.jvp(primal_out)

      return _compute_forwardgrad

    def _forward(x):
      return x**3.5

    f = _forward
    primal = constant_op.constant(1.1)
    for _ in range(order):
      f = _forwardgrad(f)
    self.assertAllClose(expected, f(primal))

  @parameterized.named_parameters([("Function", def_function.function),
                                   ("NoFunction", lambda f: f)])
  def testGradPureForward(self, decorator):

    @decorator
    def f(x):
      return x**3.5

    primal = constant_op.constant(1.1)
    with forwardprop.ForwardAccumulator(primal,
                                        constant_op.constant(1.)) as outer_acc:
      with forwardprop.ForwardAccumulator(primal,
                                          constant_op.constant(1.)) as acc:
        primal_out = f(primal)
    inner_jvp = acc.jvp(primal_out)
    outer_jvp = outer_acc.jvp(inner_jvp)
    self.assertAllClose(1.1**3.5, primal_out)
    self.assertAllClose(3.5 * 1.1**2.5, inner_jvp)
    self.assertAllClose(3.5 * 2.5 * 1.1**1.5, outer_jvp)
    self.assertIsNone(acc.jvp(outer_acc.jvp(primal_out)))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testJVPPacking(self):
    two = constant_op.constant(2.)
    primal_in = constant_op.constant(1.)
    inner_jvp = constant_op.constant(3.)
    with forwardprop.ForwardAccumulator(
        [primal_in, inner_jvp],
        [constant_op.constant(2.),
         constant_op.constant(4.)]) as outer_acc:
      with forwardprop.ForwardAccumulator(primal_in, inner_jvp) as inner_acc:
        packed_input_indices, packed_input_tangents = (
            forwardprop_util.pack_tangents([primal_in]))
        self.assertAllClose([3., 2., 4.], packed_input_tangents)
        expected_indices = (
            # inner_acc watches primal_in
            ((0, 1),),
            # outer_acc watches primal_in and inner_jvp
            ((0, 2), (1, 3)))
        self.assertAllEqual(expected_indices, packed_input_indices)
        primal_out = primal_in * two
        self.assertAllClose(6., inner_acc.jvp(primal_out))
        self.assertAllClose(4., outer_acc.jvp(primal_out))
        self.assertAllClose(8., outer_acc.jvp(inner_acc.jvp(primal_out)))
        packed_output_indices, packed_output_tangents = (
            forwardprop_util.pack_tangents([primal_out]))
        self.assertAllClose([6., 4., 8.], packed_output_tangents)
        self.assertAllEqual(expected_indices, packed_output_indices)

  def testFunctionGradInFunctionPureForward(self):

    @def_function.function
    def take_gradients():

      @def_function.function
      def f(x):
        return x**3.5

      primal = constant_op.constant(1.1)
      with forwardprop.ForwardAccumulator(
          primal, constant_op.constant(1.)) as outer_acc:
        with forwardprop.ForwardAccumulator(primal,
                                            constant_op.constant(1.)) as acc:
          primal_out = f(primal)
      inner_jvp = acc.jvp(primal_out)
      outer_jvp = outer_acc.jvp(inner_jvp)
      self.assertIsNone(acc.jvp(outer_acc.jvp(primal_out)))
      return primal_out, inner_jvp, outer_jvp

    primal_out, inner_jvp, outer_jvp = take_gradients()
    self.assertAllClose(1.1**3.5, primal_out)
    self.assertAllClose(3.5 * 1.1**2.5, inner_jvp)
    self.assertAllClose(3.5 * 2.5 * 1.1**1.5, outer_jvp)

  def testFunctionGrad(self):

    @def_function.function
    def f(x):
      return math_ops.reduce_prod(math_ops.tanh(x)**2)

    _test_gradients(self, f, [constant_op.constant([1., 2.])], order=3)

  def testReusingJVP(self):
    m1 = random_ops.random_uniform((256, 2096))
    m2 = array_ops.identity(m1)
    tangent1 = random_ops.random_uniform((256, 2096))
    tangent2 = random_ops.random_uniform((256, 2096))
    matmul = def_function.function(math_ops.matmul)

    with forwardprop.ForwardAccumulator(
        primals=[m1, m2], tangents=[tangent1, tangent2]) as acc:
      result1 = matmul(m1, m1, transpose_b=True)
      result2 = matmul(m2, m2, transpose_b=True)

    def _expected(mat, tangent):
      return (math_ops.matmul(tangent, mat, transpose_b=True) +
              math_ops.matmul(mat, tangent, transpose_b=True))

    self.assertAllClose(result1, result2)
    self.assertAllClose(_expected(m1, tangent1), acc.jvp(result1))
    self.assertAllClose(_expected(m2, tangent2), acc.jvp(result2))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testHVPMemory(self):

    def fun(x):
      return math_ops.reduce_prod(math_ops.tanh(x)**2)

    primals = constant_op.constant([1., 2., 3.])
    tangents = constant_op.constant([3., 4., 5.])
    _hvp(fun, (primals,), (tangents,))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testHVPCorrectness(self):

    def fun(x):
      return math_ops.reduce_prod(math_ops.tanh(x)**2)

    primals = constant_op.constant([1., 2., 3.])
    tangents = constant_op.constant([3., 4., 5.])
    forwardback_hvp_eager, = _hvp(fun, (primals,), (tangents,))
    forwardback_hvp_function, = def_function.function(_hvp)(fun, (primals,),
                                                             (tangents,))

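    # For reference, compute the same Hessian-vector product with
    # reverse-over-reverse mode: build the full Hessian with nested tapes and
    # contract it with the tangent vector.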
    with backprop.GradientTape(persistent=True) as g:
      g.watch(primals)
      with backprop.GradientTape() as gg:
        gg.watch(primals)
        out = fun(primals)
      grad = array_ops.unstack(gg.gradient(out, primals))
      hessian = []
      for i in range(3):
        hessian.append(g.gradient(grad[i], primals))
    hessian = array_ops.stack(hessian, axis=0)
    backback_hvp = math_ops.tensordot(hessian, tangents, axes=1)

    self.assertAllClose(backback_hvp, forwardback_hvp_eager)
    self.assertAllClose(backback_hvp, forwardback_hvp_function)

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testShouldRecordAndStopRecord(self):
    c = constant_op.constant(1.)
    c_tangent = constant_op.constant(2.)
    with forwardprop.ForwardAccumulator(c, c_tangent) as acc:
      with backprop.GradientTape() as tape:
        self.assertFalse(tape_lib.should_record_backprop([c]))
        self.assertEqual(1, pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
        tape.watch(c)
        self.assertEqual(2, pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
        self.assertTrue(tape_lib.should_record_backprop([c]))
        with tape_lib.stop_recording():
          self.assertEqual(0,
                           pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
          self.assertFalse(tape_lib.should_record_backprop([c]))
          d = c * 2.
        self.assertEqual(2, pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
        self.assertTrue(tape_lib.should_record_backprop([c]))
        self.assertFalse(tape_lib.should_record_backprop([d]))
        self.assertIsNone(acc.jvp(d))
      self.assertIsNone(tape.gradient(d, c))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testRecordingSelectively(self):
    c = constant_op.constant(1.)
    c_tangent = constant_op.constant(2.)
    with forwardprop.ForwardAccumulator(c, c_tangent) as acc:
      with backprop.GradientTape(persistent=True) as tape:
        tape.watch(c)
        with tape_lib.stop_recording():
          two = constant_op.constant(2.)
          d = c * two
          three = constant_op.constant(3.)
          e = c * three
        self.assertIsNone(acc.jvp(d))
        self.assertIsNone(acc.jvp(e))
        self.assertIsNone(tape.gradient(d, c))
        self.assertIsNone(tape.gradient(e, c))
        tape_lib.record_operation_forwardprop_only(
            "CustomForwardMul", [d], [c, two], lambda dd: (two * dd, c * dd),
            None)
        tape_lib.record_operation_backprop_only("CustomBackwardMul", [e],
                                                [c, three], lambda de:
                                                (three * de, c * de))
        self.assertAllClose(4., acc.jvp(d))
        self.assertIsNone(acc.jvp(e))
        self.assertIsNone(tape.gradient(d, c))
        self.assertAllClose(3., tape.gradient(e, c))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testOpWithNoTrainableOutputs(self):
    v = variables.Variable(1.)
    with forwardprop.ForwardAccumulator(v, 11.):
      v.assign_sub(0.5)
    self.assertAllClose(0.5, self.evaluate(v))

  # TODO(b/141025187): Add a no_new_pyobjects decorator.
  def testVariableReadInFunction(self):
    v = variables.Variable(1.)
    with forwardprop.ForwardAccumulator(v, 11.) as acc:

      @def_function.function
      def f():
        return v.read_value(), 2. * v.read_value()

      result = f()
      self.assertAllClose((1.0, 2.), result)
      self.assertAllClose((11., 22.), acc.jvp(result))

  @parameterized.named_parameters([("ForwardPropFirst", True),
                                   ("TapeFirst", False)])
  def testForwardOverBackwardMemoryEfficiency(self, forward_prop_first):
    # Watching depends on nesting, not creation order
    c = constant_op.constant(1.)
    if forward_prop_first:
      forward_accumulator = forwardprop.ForwardAccumulator(c, .1)
      gradient_tape = backprop.GradientTape()
    else:
      gradient_tape = backprop.GradientTape()
      forward_accumulator = forwardprop.ForwardAccumulator(c, .1)
    try:
      gc.disable()
      with gradient_tape as tape:
        # Adding and removing the tape multiple times in different nesting
        # patterns does not affect watch ordering.
        pass
      with forward_accumulator as acc:
        with gradient_tape as tape:
          tape.watch(c)
          d = math_ops.cos(c)
          self.assertFalse(tape_lib.should_record_backprop((acc.jvp(d),)))
          e = math_ops.cos(acc.jvp(d))
          math_ops.cos(e)
          weak_e = weakref.ref(e)
          del e
          self.assertIsNone(weak_e())
        self.assertIsNone(tape.gradient(acc.jvp(d), c))
    finally:
      gc.enable()

  @parameterized.named_parameters([("ForwardPropFirst", True),
                                   ("TapeFirst", False)])
  def testBackwardOverForward(self, forward_prop_first):
    c = constant_op.constant(1.)
    # Watching depends on nesting, not creation order
    if forward_prop_first:
      forward_accumulator = forwardprop.ForwardAccumulator(c, .1)
      gradient_tape = backprop.GradientTape()
    else:
      gradient_tape = backprop.GradientTape()
      forward_accumulator = forwardprop.ForwardAccumulator(c, .1)
    with gradient_tape as tape:
      with forward_accumulator as acc:
        tape.watch(c)
        d = math_ops.cos(c)
        self.assertTrue(tape_lib.should_record_backprop((acc.jvp(d),)))
    self.assertAllClose(-.1 * math_ops.cos(1.), tape.gradient(acc.jvp(d), c))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testRecordingWithJVPIndices(self):
    c = constant_op.constant(1.)
    with forwardprop.ForwardAccumulator(c, 10.) as acc:
      packed_input_tangents = forwardprop_util.pack_tangents([c]).tangents
      self.assertAllClose([10.], packed_input_tangents)
      d = constant_op.constant(2.)
      d_tangent = constant_op.constant(3.)
      tape_lib.record_operation_forwardprop_only("FunctionWithInlineJVPs",
                                                 [d] + [d_tangent],
                                                 [c] + packed_input_tangents,
                                                 None, (((0, 1),),))
      self.assertAllClose(3., acc.jvp(d))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testSpecialForwardFunctionUsed(self):
    c = constant_op.constant(1.)
    d = constant_op.constant(2.)
    e = constant_op.constant(3.)
    with forwardprop.ForwardAccumulator(c, 10.) as acc:
      tape_lib.record_operation("ForwardIsSpecial", [d], [c], None,
                                lambda jvp: [-2. * jvp])
      self.assertAllClose(-20., acc.jvp(d))
      tape_lib.record_operation("ForwardIsSpecial2", [], [], None, lambda: [])
      tape_lib.record_operation("ForwardIsSpecial3", [e], [d], None,
                                lambda x: [x])
      self.assertAllClose(-20., acc.jvp(e))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testVariableWatched(self):
    v = variables.Variable([1., 2., 3.])
    with forwardprop.ForwardAccumulator(v, constant_op.constant([.1, -.2,
                                                                 .3])) as acc:
      self.assertAllClose([.1, -.2, .3], acc.jvp(v))
      x = v * 2.
      self.assertAllClose([.2, -.4, .6], acc.jvp(x))
      x2 = v + .1
      self.assertAllClose([.1, -.2, .3], acc.jvp(x2))

  def testUnconnectedGradients(self):
    x = constant_op.constant(-1.)
    with forwardprop.ForwardAccumulator(x, 0.1) as acc:
      self.assertAllClose(0.1, acc.jvp(x, unconnected_gradients="zero"))
      self.assertAllClose(0.1, acc.jvp(x, unconnected_gradients="none"))
      y = constant_op.constant(-2.)
      self.assertAllClose(0.0, acc.jvp(y, unconnected_gradients="zero"))
      self.assertIsNone(acc.jvp(y, unconnected_gradients="none"))

  # TODO(kkb): One weakref instance is created with warmup_iters=2,
  # investigate.
  @test_util.assert_no_new_pyobjects_executing_eagerly(warmup_iters=3)
  def testVariableWatchedFunction(self):

    class _Model(module.Module):

      def __init__(self):
        self._v = None

      @def_function.function
      def compute_jvps(self):
        if self._v is None:
          self._v = variables.Variable([1., 2., 3.])
        with forwardprop.ForwardAccumulator(self._v,
                                            constant_op.constant([.1, -.2,
                                                                  .3])) as acc:
          x = self._v * 2.
          x2 = self._v + .1
        return acc.jvp((self._v, x, x2))

    model = _Model()
    v_jvp, x_jvp, x2_jvp = model.compute_jvps()
    self.assertAllClose([.1, -.2, .3], v_jvp)
    self.assertAllClose([.2, -.4, .6], x_jvp)
    self.assertAllClose([.1, -.2, .3], x2_jvp)

  def testIndexSlicesGrad(self):
    x = constant_op.constant([1.])

    with forwardprop.ForwardAccumulator(x, constant_op.constant([3.])) as acc:
      y = array_ops.gather(x, 0)
    self.assertAllClose(3., acc.jvp(y))

  def testIndexSlicesGradInFunction(self):

    @def_function.function
    def f(a):
      return array_ops.gather(a, 0)

    x = constant_op.constant([1.])

    with forwardprop.ForwardAccumulator(x, constant_op.constant([3.])) as acc:
      y = f(x)
    self.assertAllClose(3., acc.jvp(y))

  # NOTE: assert_no_new_pyobjects_executing_eagerly fails flakily on this
  # test... could be something wrong with the test decorator, or some sort of
  # nondeterministic caching.
  def testMirroredVariableWatched(self):

    def _replicated(input_tangent):
      with forwardprop.ForwardAccumulator(v, input_tangent) as acc:
        self.assertAllClose([.1, -.2, .3], acc.jvp(v))
        x = v * 2.
        self.assertAllClose([.2, -.4, .6], acc.jvp(x))
        x2 = v + .1
        self.assertAllClose([.1, -.2, .3], acc.jvp(x2))

    strategy = mirrored_strategy.MirroredStrategy()
    with strategy.scope():
      v = variables.Variable([1., 2., 3.])
      strategy.run(_replicated, args=(constant_op.constant([.1, -.2, .3]),))

  # TODO(b/141025187): Add a no_new_pyobjects decorator.
  def testArgumentUnused(self):
    v = constant_op.constant(1.)
    with forwardprop.ForwardAccumulator(v, 11.) as acc:

      @def_function.function
      def _f(x):
        del x
        return constant_op.constant(1.)

      result = _f(v)
      self.assertAllClose(1.0, result)
      self.assertIsNone(acc.jvp(result))


@def_function.function
def _has_loop(iters, y):
  ret = 0.
  for i in math_ops.range(iters):
    ret += y * math_ops.cast(i, dtypes.float32)
  return ret


@def_function.function
def _has_cond(k, y):
  if k > 1:
    ret = 3. * y
  else:
    ret = 0.
  return ret


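# In _fprop_while and _fprop_cond the accumulator lives inside the tf.function,
# so JVPs flow through the traced control flow; ControlFlowTests below also
# covers the converse case, running _has_loop and _has_cond under an
# accumulator created outside the function.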
@def_function.function
def _fprop_while(iters, y):
  with forwardprop.ForwardAccumulator(y, 1.) as acc:
    ret = 0.
    for i in math_ops.range(iters):
      ret += y * math_ops.cast(i, dtypes.float32)
    return acc.jvp(ret)


@def_function.function
def _fprop_cond(k, y):
  with forwardprop.ForwardAccumulator(y, 1.) as acc:
    if k > 1:
      ret = 3. * y
    else:
      ret = 0.
    return acc.jvp(ret)


class ControlFlowTests(test.TestCase):

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testOfFunctionWhile(self):
    y = constant_op.constant(1.)
    with forwardprop.ForwardAccumulator(y, 1.) as acc:
      self.assertAllClose(10., acc.jvp(_has_loop(constant_op.constant(5), y)))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testOfFunctionCond(self):
    y = constant_op.constant(1.)
    with forwardprop.ForwardAccumulator(y, 1.) as acc:
      self.assertAllClose(3., acc.jvp(_has_cond(constant_op.constant(5), y)))
      self.assertAllClose(0., acc.jvp(_has_cond(constant_op.constant(0), y)))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testInFunctionWhile(self):
    self.assertAllClose(
        10., _fprop_while(constant_op.constant(5), constant_op.constant(1.)))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testInFunctionCond(self):
    self.assertAllClose(
        3., _fprop_cond(constant_op.constant(5), constant_op.constant(1.)))
    self.assertAllClose(
        0., _fprop_cond(constant_op.constant(0), constant_op.constant(1.)))


class HessianTests(test.TestCase, parameterized.TestCase):

  def testHessian1D(self):
    # Note: stolen from ops/gradients_test.py
    m = 4
    rng = np.random.RandomState([1, 2, 3])
    mat_value = rng.randn(m, m).astype("float32")
    x_value = rng.randn(m).astype("float32")
    hess_value = mat_value + mat_value.T
    mat = variables.Variable(mat_value)

    def _f(x):
      return math_ops.reduce_sum(x[:, None] * mat * x[None, :])

    hessian_eager, = _forward_over_back_hessian(
        _f, [constant_op.constant(x_value)],
        use_pfor=False,
        dtype=[dtypes.float32])
    self.assertAllClose(hess_value, hessian_eager)
    hessian_function, = def_function.function(_forward_over_back_hessian)(
        _f, [constant_op.constant(x_value)],
        use_pfor=False,
        dtype=[dtypes.float32])
    self.assertAllClose(hess_value, hessian_function)
    hessian_pfor, = def_function.function(_forward_over_back_hessian)(
        _f, [constant_op.constant(x_value)],
        use_pfor=True,
        dtype=[dtypes.float32])
    self.assertAllClose(hess_value, hessian_pfor)


class BatchTests(test.TestCase, parameterized.TestCase):

  @parameterized.parameters([(math_ops.sin, (2, 3), 5),
                             (math_ops.sin, (2, 3, 4), 10)])
  def testJVPBatchCorrectness(self, f, primal_shape, batch_size):
    primals = [random_ops.random_uniform(primal_shape)]
    tangent_batch = [random_ops.random_uniform([batch_size, *primal_shape])]
    self.assertAllClose(
        _jvp_batch(f, primals, tangent_batch)[1],
        _jvp_batch_matmul(f, primals, *tangent_batch))

  def testBatchCorrectness(self):
    x = constant_op.constant(2.0)
    y = constant_op.constant(5.0)
    tangents = (
        constant_op.constant([1., 0., 1.]),
        constant_op.constant([0., 1., 1.]),
    )
    with forwardprop.ForwardAccumulator._batch_accumulator((x, y),
                                                           tangents) as acc:
      z = x * y
    self.assertAllClose(acc.jvp(z), constant_op.constant([5.0, 2.0, 7.0]))

  @parameterized.named_parameters([("ForwardPropFirst", True),
1097 ("TapeFirst", False)]) 1098 def testBatchBackwardOverForward(self, forward_prop_first): 1099 x = constant_op.constant(1.) 1100 tangents = random_ops.random_normal(shape=[10], seed=1) 1101 expected = [-t * math_ops.cos(1.) for t in tangents] 1102 if forward_prop_first: 1103 batch_acc = forwardprop.ForwardAccumulator._batch_accumulator(x, tangents) 1104 gradient_tape = backprop.GradientTape(persistent=True) 1105 else: 1106 gradient_tape = backprop.GradientTape(persistent=True) 1107 batch_acc = forwardprop.ForwardAccumulator._batch_accumulator(x, tangents) 1108 with gradient_tape as tape: 1109 with batch_acc as acc: 1110 tape.watch(x) 1111 y = math_ops.cos(x) 1112 self.assertTrue(tape_lib.should_record_backprop((acc.jvp(y),))) 1113 jvps = acc.jvp(y) 1114 d2y_dx2 = [tape.gradient(dy_dx, x) for dy_dx in jvps] 1115 self.assertAllClose(expected, d2y_dx2) 1116 1117 1118if __name__ == "__main__": 1119 # TODO(allenl): Also test with 1.x-style graph mode. 1120 ops.enable_eager_execution() 1121 test.main() 1122