# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

from absl.testing import parameterized
import numpy as np
import tensorflow as tf


def _jvp(f, primals, tangents):
  """Compute the jacobian of `f` at `primals` multiplied by `tangents`."""
  with tf.autodiff.ForwardAccumulator(primals, tangents) as acc:
    primals_out = f(*primals)
  return primals_out, acc.jvp(
      primals_out, unconnected_gradients=tf.UnconnectedGradients.ZERO)


def _jacfwd(f, primals):
  """Compute the jacobian of `f` at `primals` using forward-mode autodiff."""
  jac_flat = []
  flat_primals = tf.nest.flatten(primals)
  tangent_mask = [tf.zeros_like(primal) for primal in flat_primals]
  for primal_index, primal in enumerate(flat_primals):
    primal_vector = tf.reshape(primal, [-1])
    primal_vector_length = tf.size(primal_vector)
    jac_columns = []
    for element_index in tf.range(primal_vector_length):
      mask = tf.one_hot(element_index, primal_vector_length)
      tangent_mask[primal_index] = tf.reshape(mask, tf.shape(primal))
      jac_columns.append(
          tf.nest.map_structure(
              functools.partial(tf.reshape, shape=[-1]),
              _jvp(f, primals, tf.nest.pack_sequence_as(primals,
                                                        tangent_mask))[1]))
    jac_flat.append(tf.stack(jac_columns, axis=1))
    tangent_mask[primal_index] = tf.zeros_like(primal)
  return tf.nest.pack_sequence_as(primals, jac_flat)


def _grad(f, argnums=0):
  """Return a function which computes the gradient of `f`."""

  def _f(*params):
    with tf.GradientTape() as tape:
      tape.watch(params)
      primals_out = f(*params)
    return tape.gradient(
        primals_out,
        params[argnums],
        unconnected_gradients=tf.UnconnectedGradients.ZERO)

  return _f


def _hvp(f, primals, tangents):
  """Compute a forward-over-back Hessian-vector product."""
  with tf.autodiff.ForwardAccumulator(primals, tangents) as acc:
    with tf.GradientTape() as tape:
      tape.watch(primals)
      f_out = f(*primals)
      f_out.shape.assert_is_compatible_with([])
    return acc.jvp(tape.gradient(f_out, primals))


def _vectorize_parameters(f, params, use_pfor, dtype):
  """Loop over `params`, providing a one-hot mask to `f` for each."""
  parameter_sizes = [tf.size(param) for param in params]
  total_size = tf.math.add_n(parameter_sizes)

  def _wrapper(index):
    full_onehot = tf.one_hot(index, total_size)
    split_onehot = tf.split(full_onehot, parameter_sizes)
    tangents = [
        tf.reshape(v, tf.shape(param))
        for param, v in zip(params, split_onehot)
    ]
    return f(tangents)

  if use_pfor:
    return tf.vectorized_map(_wrapper, tf.range(total_size))
  else:
    return tf.map_fn(_wrapper, tf.range(total_size), dtype)


def _forward_over_back_hessian(f, params, use_pfor, dtype=None):
  """Computes the full Hessian matrix for the scalar-valued f(*params).

  Args:
    f: A function taking `params` and returning a scalar.
    params: A possibly nested structure of tensors.
    use_pfor: If true, uses `tf.vectorized_map` calls instead of looping.
    dtype: Required if `use_pfor=False`. A possibly nested structure of dtypes
      (e.g. `tf.float32`) matching the structure of `f`'s returns.

  Returns:
    A possibly nested structure of matrix slices corresponding to `params`. Each
    slice has shape [P, p_s] where `p_s` is the number of parameters (`tf.size`)
    in the corresponding element of `params` and `P` is the total number of
    parameters (`sum_s(p_s)`). The full matrix can be obtained by concatenating
    along the second axis.
  """
  return _vectorize_parameters(
      functools.partial(_hvp, f, params),
      params,
      use_pfor=use_pfor,
      dtype=dtype)


def _test_gradients(testcase,
                    f,
                    primals,
                    order,
                    delta=1e-3,
                    rtol=1e-2,
                    atol=1e-6):
  """Tests forward/backward jacobians of `f`'s [0, `order`)-order gradients."""
  if order < 1:
    raise ValueError(
        "`order` should be a positive integer, got '{}'.".format(order))
  if order > 1:
    _test_gradients(
        testcase=testcase,
        f=_grad(f),
        primals=primals,
        order=order - 1,
        delta=delta,
        rtol=rtol,
        atol=atol)
  sym_jac_back, num_jac = tf.test.compute_gradient(f, primals, delta=delta)
  testcase.assertAllClose(num_jac, sym_jac_back, rtol=rtol, atol=atol)
  sym_jac_fwd = _jacfwd(f, primals)
  testcase.assertAllClose(num_jac, sym_jac_fwd, rtol=rtol, atol=atol)
  # And the symbolic computations should be much closer.
  testcase.assertAllClose(sym_jac_back, sym_jac_fwd)


class ForwardpropTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.named_parameters([
      ("Dense", [[0.1]], functools.partial(tf.keras.layers.Dense, 5)),
      ("Conv2D",
       np.reshape(
           np.arange(start=-1., stop=1., step=2. / (1 * 2 * 4 * 4)),
           [1, 2, 4, 4]), functools.partial(tf.keras.layers.Conv2D, 2, 2), 1e-3)
  ])
  def testKerasLayers(self, value, op_fn, atol=1e-6):
    layer = op_fn()
    input_value = tf.constant(value, dtype=tf.float32)
    layer.build(input_value.shape)
    # Make sure the test is deterministic by avoiding random variable
    # initialization.
    for v in layer.trainable_variables:
      v.assign(
          tf.reshape(
              tf.range(
                  -1.,
                  1.,
                  2. / tf.size(v, out_type=tf.float32),
                  dtype=tf.float32), v.shape))
    _test_gradients(
        self,
        layer,
        [input_value],
        atol=atol,
        # These are linear, so second-order is pretty boring.
        order=2)

  @parameterized.named_parameters([
      ("NonFused", [[0.1], [0.2], [-0.3]],
       functools.partial(tf.keras.layers.BatchNormalization, fused=False)),
      ("Fused", [[[[0.1, 2.]]], [[[0.2, -3.]]], [[[-0.3, 4.]]]],
       functools.partial(tf.keras.layers.BatchNormalization, fused=True))
  ])
  def testBatchNorm(self, value, op_fn):
    for training in [True, False]:
      layer = op_fn()
      input_value = tf.constant(value, dtype=tf.float32)
      layer.build(input_value.shape)
      _test_gradients(
          self,
          functools.partial(layer, training=training), [input_value],
          order=2,
          atol=1e-3)

  @parameterized.named_parameters([
      ("NonFused", [[0.1], [0.2], [-0.3]],
       functools.partial(tf.keras.layers.BatchNormalization, fused=False)),
      ("Fused", [[[[0.1, 2.]]], [[[0.2, -3.]]], [[[-0.3, 4.]]]],
       functools.partial(tf.keras.layers.BatchNormalization, fused=True))
  ])
  def testBatchNormLayerParamGrads(self, value, op_fn):
    for training in [True, False]:
      layer = op_fn()
      with tf.GradientTape() as tape:
        input_value = tf.constant(value, dtype=tf.float32)
        tape.watch(input_value)
        output = layer(input_value, training=training)
      jac_back = tape.jacobian(output,
                               [input_value] + layer.trainable_variables)
      jac_forward = _jacfwd(
          lambda *args: layer(args[0], training=training),  # pylint:disable=cell-var-from-loop
          [input_value] + layer.trainable_variables)
      for backward, forward in zip(jac_back, jac_forward):
        forward = tf.reshape(forward, tf.shape(backward))
        self.assertAllClose(backward, forward)

  @parameterized.named_parameters([("Function", tf.function),
                                   ("NoFunction", lambda f: f)])
  def testVariablesHVP(self, decorator):

    class _Model(tf.Module):

      def __init__(self):
        self._first_dense = tf.keras.layers.Dense(18)
        self._conv = tf.keras.layers.Conv2D(2, 2)
        self._norm = tf.keras.layers.BatchNormalization()
        self._second_dense = tf.keras.layers.Dense(1)

      def __call__(self, x):
        x = self._first_dense(x)
        x = tf.nn.relu(x)
        x = self._norm(x)
        x = tf.nn.relu(self._conv(tf.reshape(x, [-1, 2, 3, 3])))
        return self._second_dense(x)

    model = _Model()

    def _loss():
      input_value = tf.constant([[-0.5, 1.], [0.5, -1.]])
      target = tf.constant([[-1.], [2.]])
      return tf.math.reduce_sum((model(input_value) - target)**2.)

    @decorator
    def _compute_hvps():
      with tf.GradientTape() as tape:
        loss = _loss()
      vector = tape.gradient(loss, model.trainable_variables)
      variable_input_fn = lambda unused_variables: _loss()
      forward_over_back_hvp, = _hvp(variable_input_fn,
                                    [model.trainable_variables], [vector])
      with tf.GradientTape(persistent=True) as tape:
        tape.watch(model.trainable_variables)
        loss = _loss()
        first_grads = tape.gradient(loss, model.trainable_variables)
      back_over_back_hvp = tape.gradient(
          first_grads, model.trainable_variables, output_gradients=vector)
      return forward_over_back_hvp, back_over_back_hvp

    self.assertAllClose(*_compute_hvps(), rtol=1e-5, atol=1e-5)

  def testEmbeddingLayerInFunction(self):

    class M(tf.keras.Model):

      def __init__(self):
        super(M, self).__init__()
        self.embed = tf.keras.layers.Embedding(5, 1)
        self.proj = tf.keras.layers.Dense(1)

      @tf.function
      def call(self, x):
        return self.proj(self.embed(x))

    model = M()
    model(tf.zeros([3, 3], dtype=tf.int32))  # pylint: disable=not-callable
    parameters = model.embed.variables
    tangents = [tf.ones_like(v) for v in parameters]
    with tf.autodiff.ForwardAccumulator(parameters, tangents):
      # Note that forwardprop runs alongside the original computation. This
      # test is just checking that it doesn't crash; correctness is tested in
      # core TF.
      model(tf.zeros([3, 3], dtype=tf.int32))  # pylint: disable=not-callable


class HessianTests(tf.test.TestCase, parameterized.TestCase):

  @parameterized.named_parameters([("PFor", True), ("MapFn", False)])
  def testHessianOfVariables(self, use_pfor):
    model = tf.keras.layers.Dense(1)
    model.build([None, 2])

    def _loss(*unused_args):
      input_value = tf.constant([[-0.5, 1.], [0.5, -1.]])
      target = tf.constant([[-1.], [2.]])
      return tf.math.reduce_sum((model(input_value) - target)**2.)

    kernel_hess, bias_hess = _forward_over_back_hessian(
        _loss, [model.kernel, model.bias],
        use_pfor=use_pfor,
        dtype=[tf.float32, tf.float32])
    # 3 total parameters, the whole hessian is the 3x3 concatenation
    self.assertEqual([3, 2, 1], kernel_hess.shape)
    self.assertEqual([3, 1], bias_hess.shape)
    full_hessian = tf.concat([tf.reshape(kernel_hess, [3, 2]), bias_hess],
                             axis=1)
    # The full Hessian should be symmetric.
    self.assertAllClose(full_hessian, tf.transpose(full_hessian))


if __name__ == "__main__":
  tf.test.main()