1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utilities for forward-mode automatic differentiation.""" 16 17import functools 18import threading 19 20from tensorflow.python import pywrap_tfe 21from tensorflow.python.eager import backprop 22from tensorflow.python.eager import backprop_util 23from tensorflow.python.eager import execute 24from tensorflow.python.eager import forwardprop_util 25from tensorflow.python.eager import function 26 27from tensorflow.python.framework import ops 28from tensorflow.python.framework import tensor_shape 29 30from tensorflow.python.ops import array_ops 31from tensorflow.python.ops.parallel_for import control_flow_ops 32from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients 33from tensorflow.python.platform import tf_logging as logging 34from tensorflow.python.util import nest 35from tensorflow.python.util.tf_export import tf_export 36 37 38# Dictionary mapping from op names to special-cased jvp functions. Otherwise 39# backward functions are transposed on the tape. 
_SPECIAL_CASES = {}


def _identity_jvp(attr_tuple, inputs, outputs, tangents):
  """JVP rule for ops that forward their input unchanged.

  Special-cased mostly for resource handles, where creating ones Tensors from
  handle data for transposing the backward function on the tape is error-prone
  (even if we get good handle data, partially defined shapes are an issue).

  Args:
    attr_tuple: Attributes of the operation (unused).
    inputs: Input Tensors of the operation (unused).
    outputs: Output Tensors of the operation (unused).
    tangents: A flat list of input tangents.

  Returns:
    A list containing `array_ops.identity` of each element of `tangents`.
  """
  del attr_tuple, inputs, outputs
  return [array_ops.identity(t) for t in tangents]


_SPECIAL_CASES["Identity"] = _identity_jvp


def _read_variable_jvp(attr_tuple, inputs, outputs, tangents):
  """JVP rule for `ReadVariableOp`; behaves exactly like `_identity_jvp`.

  Like for Identity, this special case means we don't need to create
  variable-shaped Tensors from resource handles.
  """
  return _identity_jvp(attr_tuple, inputs, outputs, tangents)


_SPECIAL_CASES["ReadVariableOp"] = _read_variable_jvp


_TRACE_COUNT_CONSISTENCY_LOCK = threading.Lock()
# Map from op names to number of traces of _jvp_helper. Used to cap the number
# of traces due to shape differences while still specializing where possible.
_TRACE_COUNT = {}


def _jvp_helper(op_name, attr_tuple, inputs, outputs, tangents):
  """Computes a Jacobian-vector product for an op.

  Note that this function would be wasteful if executed eagerly. It runs the
  backward gradient function and throws away the result just to record its
  operations on a GradientTape. These unused ops are pruned away when this
  function is traced.

  Ops registered in `_SPECIAL_CASES` skip the tape-based computation and use
  their hand-written rule instead.

  Args:
    op_name: A string, the type of operation being executed.
    attr_tuple: Attributes of the operation.
    inputs: A flat list of input Tensors to the operation.
    outputs: A flat list of output Tensors from the operation.
    tangents: A flat list of Tensors, same shape as `inputs`.

  Returns:
    A flat list of tangents corresponding to `outputs`. Entries for
    non-trainable outputs are `None`. An empty list is returned when the op
    has no outputs.
  """
  with _TRACE_COUNT_CONSISTENCY_LOCK:
    # Just make sure writes don't clobber each other's increments; reads in
    # _jvp_dispatch do not lock.
    _TRACE_COUNT[op_name] = _TRACE_COUNT.get(op_name, 0) + 1

  special_case = _SPECIAL_CASES.get(op_name, None)
  if special_case is not None:
    return special_case(attr_tuple, inputs, outputs, tangents)
  if not outputs:
    # tape.gradients([], inputs) doesn't make much sense
    return []
  # Generally inner GradientTapes won't function while outer accumulators are
  # recording. We temporarily reset forwardprop state to allow GradientTapes to
  # function here.
  with forwardprop_util.push_forwardprop_state():
    # Only trainable inputs are watched; their tangents are the "vector" that
    # is pushed through the transposed backward function below.
    trainable_inputs = []
    nontrivial_tangents = []
    for input_index, tensor in enumerate(inputs):
      if backprop_util.IsTrainable(tensor):
        trainable_inputs.append(tensor)
        nontrivial_tangents.append(tangents[input_index])

    with backprop.GradientTape() as transpose_tape:
      with backprop.GradientTape() as backfunc_tape:
        backfunc_tape.watch(trainable_inputs)
        execute.record_gradient(op_name, inputs, attr_tuple, outputs)

      # Placeholder output gradients ("aids") for the backward function.
      # `transpose_tape` later differentiates the backward function's result
      # with respect to them, which transposes the backward function.
      forwardprop_aids = []
      trainable_outputs = []
      nontrivial_output_indices = []
      for output_index, output in enumerate(outputs):
        if backprop_util.IsTrainable(output):
          forwardprop_aids.append(
              array_ops.ones_like(output, name="unused_forwardprop_aid"))
          trainable_outputs.append(output)
          nontrivial_output_indices.append(output_index)

      transpose_tape.watch(forwardprop_aids)
      # Run the backward function while `transpose_tape` is recording so its
      # operations can be differentiated with respect to `forwardprop_aids`.
      grads = backfunc_tape.gradient(
          trainable_outputs,
          trainable_inputs,
          forwardprop_aids,
          unconnected_gradients=UnconnectedGradients.ZERO)
    nontrivial_output_tangents = transpose_tape.gradient(
        grads, forwardprop_aids, output_gradients=nontrivial_tangents)
    # Scatter tangents back to their output positions; non-trainable outputs
    # keep a `None` tangent.
    output_tangents = [None] * len(outputs)
    for index, tangent in zip(nontrivial_output_indices,
                              nontrivial_output_tangents):
      output_tangents[index] = tangent
    return output_tangents


def _jvp_helper_wrapper(op_name, attr_tuple, inputs, outputs, tangents,
                        use_batch):
  """Computes a batch of Jacobian-vector products for an op.

  Args:
    op_name: A string, the type of operation being executed.
    attr_tuple: Attributes of the operation.
    inputs: A flat list of input Tensors to the operation.
    outputs: A flat list of output Tensors from the operation.
    tangents: A flat list of Tensors, compatible with shape `[None] +
      input_shape`.
    use_batch: A bool, True to vectorize over a batch of tangents of shape
      `[None] + input_shape`.

  Returns:
    A flat list of tangents compatible with `outputs`
    or `[None] + output_shape`.

  Raises:
    ValueError: if tangent shapes are not compatible with input shapes.
  """
  if not use_batch:
    return _jvp_helper(op_name, attr_tuple, inputs, outputs, tangents)

  for primal, tangent in zip(inputs, tangents):
    expected_shape = [None] + primal.shape
    if not tangent.shape.is_compatible_with(expected_shape):
      raise ValueError("Tangent {} was expected to be of shape "
                       "{} but is instead of shape {}".format(
                           tangent, expected_shape, tangent.shape))

  # Apply the single-example helper across the leading (batch) dimension of
  # the tangents.
  return control_flow_ops.vectorized_map(
      functools.partial(_jvp_helper, op_name, attr_tuple, inputs, outputs),
      tangents,
  )


# TODO(allenl): reduce_retracing for gradients which rely on static
# shape information are underspecialized. We may want hand-written forward
# implementations, or a more satisfying story about how we re-specialize
# gradients which were traced with relaxed shapes (e.g. use conds instead of
# trace-time Python logic).
#
# Using function.defun rather than def_function.function avoids
# tf.config.run_functions_eagerly(True). `_jvp_helper` doesn't successfully run
# eagerly (infinite recursion), and even if it did it would use extra memory and
# run unnecessary computation. The function does not create variables, so the
# two symbols are otherwise equivalent.
_jvp_relaxed_shapes = function.defun(
    _jvp_helper_wrapper, reduce_retracing=True)
_jvp_exact_shapes = function.defun(
    _jvp_helper_wrapper, reduce_retracing=False)

# The maximum number of exact-shape traces to perform for a single op before
# switching to shape relaxation.
_TRACE_COUNT_LIMIT = 32


def _jvp_dispatch(op_name,
                  attr_tuple,
                  inputs,
                  outputs,
                  tangents,
                  use_batch=False):
  """Determine which forwardprop function to call.

  Ops traced fewer than `_TRACE_COUNT_LIMIT` times use the exact-shape
  function; after that the relaxed-shape function is used to cap retracing.
  """
  # Note that this _TRACE_COUNT read races with writes. That's fine, it just
  # means we may trace a few more exact shapes before moving on to relaxation.
  if _TRACE_COUNT.get(op_name, 0) < _TRACE_COUNT_LIMIT:
    return _jvp_exact_shapes(op_name, attr_tuple, inputs, outputs, tangents,
                             use_batch)
  return _jvp_relaxed_shapes(op_name, attr_tuple, inputs, outputs, tangents,
                             use_batch)


pywrap_tfe.TFE_Py_RegisterJVPFunction(_jvp_dispatch)


@tf_export("autodiff.ForwardAccumulator", v1=[])
class ForwardAccumulator:
  """Computes Jacobian-vector products ("JVP"s) using forward-mode autodiff.

  Compare to `tf.GradientTape` which computes vector-Jacobian products ("VJP"s)
  using reverse-mode autodiff (backprop). Reverse mode is more attractive when
  computing gradients of a scalar-valued function with respect to many inputs
  (e.g. a neural network with many parameters and a scalar loss). Forward mode
  works best on functions with many outputs and few inputs. Since it does not
  hold on to intermediate activations, it is much more memory efficient than
  backprop where it is applicable.

  Consider a simple linear regression:

  >>> x = tf.constant([[2.0, 3.0], [1.0, 4.0]])
  >>> targets = tf.constant([[1.], [-1.]])
  >>> dense = tf.keras.layers.Dense(1)
  >>> dense.build([None, 2])
  >>> with tf.autodiff.ForwardAccumulator(
  ...    primals=dense.kernel,
  ...    tangents=tf.constant([[1.], [0.]])) as acc:
  ...   loss = tf.reduce_sum((dense(x) - targets) ** 2.)
  >>> acc.jvp(loss)
  <tf.Tensor: shape=(), dtype=float32, numpy=...>

  The example has two variables containing parameters, `dense.kernel` (2
  parameters) and `dense.bias` (1 parameter). Considering the training data `x`
  as a constant, this means the Jacobian matrix for the function mapping from
  parameters to loss has one row and three columns.

  With forwardprop, we specify a length-three vector in advance which multiplies
  the Jacobian. The `primals` constructor argument is the parameter (a
  `tf.Tensor` or `tf.Variable`) we're specifying a vector for, and the
  `tangents` argument is the "vector" in Jacobian-vector product. If our goal is
  to compute the entire Jacobian matrix, forwardprop computes one column at a
  time while backprop computes one row at a time. Since the Jacobian in the
  linear regression example has only one row, backprop requires fewer
  invocations:

  >>> x = tf.constant([[2.0, 3.0], [1.0, 4.0]])
  >>> targets = tf.constant([[1.], [-1.]])
  >>> dense = tf.keras.layers.Dense(1)
  >>> dense.build([None, 2])
  >>> loss_fn = lambda: tf.reduce_sum((dense(x) - targets) ** 2.)
  >>> kernel_fprop = []
  >>> with tf.autodiff.ForwardAccumulator(
  ...     dense.kernel, tf.constant([[1.], [0.]])) as acc:
  ...   kernel_fprop.append(acc.jvp(loss_fn()))
  >>> with tf.autodiff.ForwardAccumulator(
  ...     dense.kernel, tf.constant([[0.], [1.]])) as acc:
  ...   kernel_fprop.append(acc.jvp(loss_fn()))
  >>> with tf.autodiff.ForwardAccumulator(dense.bias, tf.constant([1.])) as acc:
  ...   bias_fprop = acc.jvp(loss_fn())
  >>> with tf.GradientTape() as tape:
  ...   loss = loss_fn()
  >>> kernel_grad, bias_grad = tape.gradient(loss, (dense.kernel, dense.bias))
  >>> np.testing.assert_allclose(
  ...     kernel_grad, tf.stack(kernel_fprop)[:, tf.newaxis])
  >>> np.testing.assert_allclose(bias_grad, bias_fprop[tf.newaxis])

  Implicit in the `tape.gradient` call is a length-one vector which
  left-multiplies the Jacobian, a vector-Jacobian product.

  `ForwardAccumulator` maintains JVPs corresponding to primal tensors it is
  watching, derived from the original `primals` specified in the constructor. As
  soon as a primal tensor is deleted, `ForwardAccumulator` deletes the
  corresponding JVP.

  `acc.jvp(x)` retrieves `acc`'s JVP corresponding to the primal tensor `x`. It
  does not perform any computation. `acc.jvp` calls can be repeated as long as
  `acc` is accessible, whether the context manager is active or not. New JVPs
  are only computed while the context manager is active.

  Note that `ForwardAccumulator`s are always applied in the order their context
  managers were entered, so inner accumulators will not see JVP computation from
  outer accumulators. Take higher-order JVPs from outer accumulators:

  >>> primal = tf.constant(1.1)
  >>> with tf.autodiff.ForwardAccumulator(primal, tf.constant(1.)) as outer:
  ...   with tf.autodiff.ForwardAccumulator(primal, tf.constant(1.)) as inner:
  ...     primal_out = primal ** tf.constant(3.5)
  >>> inner_jvp = inner.jvp(primal_out)
  >>> inner_jvp  # 3.5 * 1.1 ** 2.5
  <tf.Tensor: shape=(), dtype=float32, numpy=4.4417057>
  >>> outer.jvp(inner_jvp)  # 3.5 * 2.5 * 1.1 ** 1.5
  <tf.Tensor: shape=(), dtype=float32, numpy=10.094786>

  Reversing the nesting of the `jvp` calls in the last line to instead retrieve
  `inner.jvp(outer.jvp(primal_out))` will not work.

  Strict nesting also applies to combinations of `ForwardAccumulator` and
  `tf.GradientTape`. More deeply nested `GradientTape` objects will ignore the
  products of outer `ForwardAccumulator` objects. This allows (for example)
  memory-efficient forward-over-backward computation of Hessian-vector products,
  where the inner `GradientTape` would otherwise hold on to all intermediate
  JVPs:

  >>> v = tf.Variable([1., 2.])
  >>> with tf.autodiff.ForwardAccumulator(
  ...     v,
  ...     # The "vector" in Hessian-vector product.
  ...     tf.constant([1., 0.])) as acc:
  ...   with tf.GradientTape() as tape:
  ...     y = tf.reduce_sum(v ** 3.)
  ...   backward = tape.gradient(y, v)
  >>> backward  # gradient from backprop
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([ 3., 12.], dtype=float32)>
  >>> acc.jvp(backward)  # forward-over-backward Hessian-vector product
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([6., 0.], dtype=float32)>
  """

  def __init__(self, primals, tangents):
    """Specify tensors to watch and their Jacobian-vector products.

    Mathematically, `tangents` is a vector right-multiplying the Jacobian matrix
    (a Jacobian-vector product) for the function computed while this accumulator
    is active. Since JVPs are computed in forward mode as the computation
    happens, this vector must be supplied in advance.

    Listing a single tensor multiple times in `primals` raises an
    exception. Excluding a tensor from `primals` is equivalent to watching it
    with a tangent tensor of zeros.

    Args:
      primals: A tensor or nested structure of tensors to watch.
      tangents: A tensor or nested structure of tensors, with the same nesting
        structure as `primals`, with each element being a vector with the same
        size as the corresponding primal element.

    Raises:
      ValueError: If the same tensor or variable is specified multiple times in
        `primals`.
    """
    # `False` here; `_batch_accumulator` passes `True` for the batched variant.
    self._accumulator = pywrap_tfe.TFE_Py_ForwardAccumulatorNew(False)
    self._recording = False
    primal_ids = set()
    for primal in nest.flatten(primals):
      if id(primal) in primal_ids:
        raise ValueError(
            "Tensor {} was specified as a primal multiple times. This may "
            "indicate an error. If it was intended, please sum the "
            "corresponding tangents.".format(primal))
      primal_ids.add(id(primal))
    self._watch(primals, tangents)

  def __enter__(self):
    """Starts recording JVPs with this accumulator and returns it."""
    self._push_accumulator()
    return self

  def __exit__(self, typ, value, traceback):
    """Stops recording if this accumulator is still recording."""
    if self._recording:
      self._pop_accumulator()

  def _push_accumulator(self):
    """Registers this accumulator as active.

    Raises:
      ValueError: If the accumulator is already recording.
    """
    if self._recording:
      raise ValueError("Accumulator is already recording.")
    pywrap_tfe.TFE_Py_ForwardAccumulatorSetAdd(self._accumulator)
    self._recording = True

  def _pop_accumulator(self):
    """Unregisters this accumulator.

    Raises:
      ValueError: If the accumulator is not recording.
    """
    if not self._recording:
      raise ValueError("Accumulator is not recording.")
    pywrap_tfe.TFE_Py_ForwardAccumulatorSetRemove(self._accumulator)
    self._recording = False

  def _watch(self, primals, tangents):
    """Ensures that `primals` are being traced by this accumulator.

    Mathematically, `tangents` is a vector right-multiplying the Jacobian matrix
    (a Jacobian-vector product) for the function computed while this accumulator
    is active. Since JVPs are computed in forward mode as the computation
    happens, this vector must be supplied in advance.

    Watching a single tensor multiple times sums each of its `tangents`. Any
    un-watched tensor has zeros for its tangent vector.

    Args:
      primals: A Tensor or list of Tensors.
      tangents: A Tensor or list of Tensors matching `primals`.
    """

    def _watch_one(primal, tangent):
      if not primal.dtype.is_floating:
        logging.log_first_n(
            logging.WARN, "The dtype of the watched primal must be "
            "floating (e.g. tf.float32), got %r", 5, primal.dtype)
      tangent = ops.convert_to_tensor(tangent, dtype=primal.dtype)
      if hasattr(primal, "handle"):
        # Run convert_to_tensor to get the captured handle from whichever
        # function we're running if necessary.
        primal = ops.convert_to_tensor(primal.handle)
      pywrap_tfe.TFE_Py_ForwardAccumulatorWatch(self._accumulator, primal,
                                                tangent)

    nest.map_structure(_watch_one, primals, tangents)

  def jvp(self, primals, unconnected_gradients=UnconnectedGradients.NONE):
    """Fetches the Jacobian-vector product computed for `primals`.

    Note that this method performs no computation, and simply looks up a JVP
    that was already computed (unlike backprop using a `tf.GradientTape`, where
    the computation happens on the call to `tape.gradient`).

    Args:
      primals: A watched Tensor or structure of Tensors to fetch the JVPs for.
      unconnected_gradients: A value which can either hold 'none' or 'zero' and
        alters the value which will be returned if no JVP was computed for
        `primals`. The possible values and effects are detailed in
        'tf.UnconnectedGradients' and it defaults to 'none'.

    Returns:
      Tensors with the same shapes and dtypes as `primals`, or None if no JVP
      is available.
    """
    unconnected_gradients = UnconnectedGradients(unconnected_gradients)
    if self._accumulator is None:
      raise ValueError("Called jvp() without first tracing anything.")

    def _fetch_jvp(tensor):
      # Variables are looked up by their handle, matching how `_watch`
      # registers them.
      if hasattr(tensor, "handle"):
        unwrapped_tensor = ops.convert_to_tensor(tensor.handle)
      else:
        unwrapped_tensor = tensor
      result = pywrap_tfe.TFE_Py_ForwardAccumulatorJVP(self._accumulator,
                                                       unwrapped_tensor)
      if result is None and unconnected_gradients == UnconnectedGradients.ZERO:
        result = array_ops.zeros_like(tensor)
      return result

    return nest.map_structure(_fetch_jvp, primals)

  @classmethod
  def _batch_accumulator(cls, primals, tangents):
    """Factory constructor to test accumulator on batches of tangents.

    Args:
      primals: A tensor or nested structure of tensors to watch.
      tangents: A tensor or nested structure of tensors, with the same nesting
        structure as `primals`, with each element being a vector with compatible
        shape `[None] + primal.shape` of the corresponding primal element.

    Returns:
      A batch accumulator object.

    Raises:
      ValueError: If the same tensor or variable is specified multiple times in
        `primals`, or if a tangent's shape is incompatible with
        `[None] + primal.shape`.
    """
    # Bypass `__init__`, which would create a non-batch accumulator. The base
    # `__new__` needs no constructor arguments.
    acc = super(ForwardAccumulator, cls).__new__(cls)
    acc._recording = False
    acc._accumulator = pywrap_tfe.TFE_Py_ForwardAccumulatorNew(True)
    primal_ids = set()
    for primal, tangent in zip(nest.flatten(primals), nest.flatten(tangents)):
      tangent.shape.assert_is_compatible_with(
          tensor_shape.TensorShape([None]) + primal.shape)
      if id(primal) in primal_ids:
        raise ValueError(
            "Tensor {} was specified as a primal multiple times. This may "
            "indicate an error. If it was intended, please sum the "
            "corresponding tangents.".format(primal))
      primal_ids.add(id(primal))
    acc._watch(primals, tangents)
    return acc