# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for forward-mode automatic differentiation."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import threading

from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import backprop
from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import execute
from tensorflow.python.eager import forwardprop_util
from tensorflow.python.eager import function

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape

from tensorflow.python.ops import array_ops
from tensorflow.python.ops.parallel_for import control_flow_ops
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


# Dictionary mapping from op names to special-cased jvp functions. Otherwise
# backward functions are transposed on the tape.
_SPECIAL_CASES = {}


def _identity_jvp(attr_tuple, inputs, outputs, tangents):
  # Special-cased mostly for resource handles, where creating ones Tensors from
  # handle data for transposing the backward function on the tape is
  # error-prone (even if we get good handle data, partially defined shapes are
  # an issue).
  del attr_tuple, inputs, outputs
  return [array_ops.identity(t) for t in tangents]


_SPECIAL_CASES["Identity"] = _identity_jvp


def _read_variable_jvp(attr_tuple, inputs, outputs, tangents):
  # Like for Identity, this special case means we don't need to create
  # variable-shaped Tensors from resource handles.
  del attr_tuple, inputs, outputs
  return [array_ops.identity(t) for t in tangents]


_SPECIAL_CASES["ReadVariableOp"] = _read_variable_jvp


_TRACE_COUNT_CONSISTENCY_LOCK = threading.Lock()
# Map from op names to number of traces of _jvp_helper. Used to cap the number
# of traces due to shape differences while still specializing where possible.
_TRACE_COUNT = {}


def _jvp_helper(op_name, attr_tuple, inputs, outputs, tangents):
  """Computes a Jacobian-vector product for an op.

  Note that this function would be wasteful if executed eagerly. It runs the
  backward gradient function and throws away the result just to record its
  operations on a GradientTape. These unused ops are pruned away when this
  function is traced.

  Args:
    op_name: A string, the type of operation being executed.
    attr_tuple: Attributes of the operation.
    inputs: A flat list of input Tensors to the operation.
    outputs: A flat list of output Tensors from the operation.
    tangents: A flat list of Tensors, same shape as `inputs`.

  Returns:
    A flat list of tangents corresponding to `outputs`.
  """
  with _TRACE_COUNT_CONSISTENCY_LOCK:
    # Just make sure writes don't clobber each other's increments; reads in
    # _jvp_dispatch do not lock.
    _TRACE_COUNT[op_name] = _TRACE_COUNT.get(op_name, 0) + 1

  special_case = _SPECIAL_CASES.get(op_name, None)
  if special_case is not None:
    return special_case(attr_tuple, inputs, outputs, tangents)
  if not outputs:
    # tape.gradients([], inputs) doesn't make much sense
    return []
  # Generally inner GradientTapes won't function while outer accumulators are
  # recording. We temporarily reset forwardprop state to allow GradientTapes to
  # function here.
  with forwardprop_util.push_forwardprop_state():
    trainable_inputs = []
    trainable_indices = []
    nontrivial_tangents = []
    for input_index, tensor in enumerate(inputs):
      if backprop_util.IsTrainable(tensor):
        trainable_inputs.append(tensor)
        trainable_indices.append(input_index)
        nontrivial_tangents.append(tangents[input_index])

    with backprop.GradientTape() as transpose_tape:
      with backprop.GradientTape() as backfunc_tape:
        backfunc_tape.watch(trainable_inputs)
        execute.record_gradient(op_name, inputs, attr_tuple, outputs)

      forwardprop_aids = []
      trainable_outputs = []
      nontrivial_output_indices = []
      for output_index, output in enumerate(outputs):
        if backprop_util.IsTrainable(output):
          forwardprop_aids.append(
              array_ops.ones_like(output, name="unused_forwardprop_aid"))
          trainable_outputs.append(output)
          nontrivial_output_indices.append(output_index)

      transpose_tape.watch(forwardprop_aids)
      grads = backfunc_tape.gradient(
          trainable_outputs,
          trainable_inputs,
          forwardprop_aids,
          unconnected_gradients=UnconnectedGradients.ZERO)
    nontrivial_output_tangents = transpose_tape.gradient(
        grads, forwardprop_aids, output_gradients=nontrivial_tangents)
    output_tangents = [None] * len(outputs)
    for index, tangent in zip(nontrivial_output_indices,
                              nontrivial_output_tangents):
      output_tangents[index] = tangent
    return output_tangents
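

# A minimal sketch of the same transpose trick as `_jvp_helper`, written with
# only public APIs (illustrative only, not used by this module; `f`, `x`, and
# `tangent` are placeholders): the inner tape computes J^T @ u for an
# arbitrary "aid" vector u, and differentiating that linear function of u
# with `tangent` as the output gradient yields J @ tangent.
#
#   import tensorflow as tf
#
#   def jvp_via_double_vjp(f, x, tangent):
#     with tf.GradientTape() as transpose_tape:
#       with tf.GradientTape() as backfunc_tape:
#         backfunc_tape.watch(x)
#         y = f(x)
#       aid = tf.ones_like(y)  # stands in for the arbitrary vector u
#       transpose_tape.watch(aid)
#       # grads is J^T @ u, a linear function of the aid vector u.
#       grads = backfunc_tape.gradient(y, x, output_gradients=aid)
#     # The VJP of u -> J^T @ u contracted with `tangent` is J @ tangent.
#     return transpose_tape.gradient(grads, aid, output_gradients=tangent)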


def _jvp_helper_wrapper(op_name, attr_tuple, inputs, outputs, tangents,
                        use_batch):
  """Computes a batch of Jacobian-vector products for an op.

  Args:
    op_name: A string, the type of operation being executed.
    attr_tuple: Attributes of the operation.
    inputs: A flat list of input Tensors to the operation.
    outputs: A flat list of output Tensors from the operation.
    tangents: A flat list of Tensors, compatible with shape `[None] +
      input_shape`.
    use_batch: A bool, True to vectorize over a batch of tangents of shape
      `[None] + input_shape`.

  Returns:
    A flat list of tangents compatible with `outputs`
    or `[None] + output_shape`.

  Raises:
    ValueError: if tangent shapes are not compatible with input shapes.
  """
  if use_batch:
    for primal, tangent in zip(inputs, tangents):
      if not tangent.shape.is_compatible_with([None] + primal.shape):
        raise ValueError("Tangent {} was expected to be of shape "
                         "{} but is instead of shape {}".format(
                             tangent, [None] + primal.shape, tangent.shape))

    return control_flow_ops.vectorized_map(
        functools.partial(_jvp_helper, op_name, attr_tuple, inputs, outputs),
        tangents,
    )
  return _jvp_helper(op_name, attr_tuple, inputs, outputs, tangents)


# TODO(allenl): experimental_relax_shapes for gradients which rely on static
# shape information are underspecialized. We may want hand-written forward
# implementations, or a more satisfying story about how we re-specialize
# gradients which were traced with relaxed shapes (e.g. use conds instead of
# trace-time Python logic).
#
# Using function.defun rather than def_function.function avoids
# tf.config.run_functions_eagerly(True). `_jvp_helper` doesn't successfully run
# eagerly (infinite recursion), and even if it did it would use extra memory
# and run unnecessary computation. The function does not create variables, so
# the two symbols are otherwise equivalent.
_jvp_relaxed_shapes = function.defun(
    _jvp_helper_wrapper, experimental_relax_shapes=True)
_jvp_exact_shapes = function.defun(
    _jvp_helper_wrapper, experimental_relax_shapes=False)

# The maximum number of exact-shape traces to perform for a single op before
# switching to shape relaxation.
_TRACE_COUNT_LIMIT = 32


def _jvp_dispatch(op_name,
                  attr_tuple,
                  inputs,
                  outputs,
                  tangents,
                  use_batch=False):
  """Determine which forwardprop function to call."""
  # Note that this _TRACE_COUNT read races with writes. That's fine, it just
  # means we may trace a few more exact shapes before moving on to relaxation.
  if _TRACE_COUNT.get(op_name, 0) < _TRACE_COUNT_LIMIT:
    return _jvp_exact_shapes(op_name, attr_tuple, inputs, outputs, tangents,
                             use_batch)
  return _jvp_relaxed_shapes(op_name, attr_tuple, inputs, outputs, tangents,
                             use_batch)


pywrap_tfe.TFE_Py_RegisterJVPFunction(_jvp_dispatch)
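

# Rough sketch of how the registered function above is used (illustrative
# pseudo-code only; the real bookkeeping lives in the C++ forward accumulator,
# and `tangent_for`/`remember_tangent_for` are stand-ins for it): while an
# accumulator is recording, each executed op triggers roughly
#
#   input_tangents = [tangent_for(t) for t in op_inputs]
#   output_tangents = _jvp_dispatch(op_name, attrs, op_inputs, op_outputs,
#                                   input_tangents)
#   for output, tangent in zip(op_outputs, output_tangents):
#     remember_tangent_for(output, tangent)
#
# so tangents flow forward through the computation as it runs, and
# `ForwardAccumulator.jvp` later looks up the remembered tangent instead of
# computing anything new.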


@tf_export("autodiff.ForwardAccumulator", v1=[])
class ForwardAccumulator():
  """Computes Jacobian-vector products ("JVP"s) using forward-mode autodiff.

  Compare to `tf.GradientTape` which computes vector-Jacobian products ("VJP"s)
  using reverse-mode autodiff (backprop). Reverse mode is more attractive when
  computing gradients of a scalar-valued function with respect to many inputs
  (e.g. a neural network with many parameters and a scalar loss). Forward mode
  works best on functions with many outputs and few inputs. Since it does not
  hold on to intermediate activations, it is much more memory efficient than
  backprop where it is applicable.

  Consider a simple linear regression:

  >>> x = tf.constant([[2.0, 3.0], [1.0, 4.0]])
  >>> targets = tf.constant([[1.], [-1.]])
  >>> dense = tf.keras.layers.Dense(1)
  >>> dense.build([None, 2])
  >>> with tf.autodiff.ForwardAccumulator(
  ...     primals=dense.kernel,
  ...     tangents=tf.constant([[1.], [0.]])) as acc:
  ...   loss = tf.reduce_sum((dense(x) - targets) ** 2.)
  >>> acc.jvp(loss)
  <tf.Tensor: shape=(), dtype=float32, numpy=...>

  The example has two variables containing parameters, `dense.kernel` (2
  parameters) and `dense.bias` (1 parameter). Considering the training data `x`
  as a constant, this means the Jacobian matrix for the function mapping from
  parameters to loss has one row and three columns.

  With forwardprop, we specify a length-three vector in advance which
  multiplies the Jacobian. The `primals` constructor argument is the parameter
  (a `tf.Tensor` or `tf.Variable`) we're specifying a vector for, and the
  `tangents` argument is the "vector" in Jacobian-vector product. If our goal
  is to compute the entire Jacobian matrix, forwardprop computes one column at
  a time while backprop computes one row at a time. Since the Jacobian in the
  linear regression example has only one row, backprop requires fewer
  invocations:

  >>> x = tf.constant([[2.0, 3.0], [1.0, 4.0]])
  >>> targets = tf.constant([[1.], [-1.]])
  >>> dense = tf.keras.layers.Dense(1)
  >>> dense.build([None, 2])
  >>> loss_fn = lambda: tf.reduce_sum((dense(x) - targets) ** 2.)
  >>> kernel_fprop = []
  >>> with tf.autodiff.ForwardAccumulator(
  ...     dense.kernel, tf.constant([[1.], [0.]])) as acc:
  ...   kernel_fprop.append(acc.jvp(loss_fn()))
  >>> with tf.autodiff.ForwardAccumulator(
  ...     dense.kernel, tf.constant([[0.], [1.]])) as acc:
  ...   kernel_fprop.append(acc.jvp(loss_fn()))
  >>> with tf.autodiff.ForwardAccumulator(dense.bias, tf.constant([1.])) as acc:
  ...   bias_fprop = acc.jvp(loss_fn())
  >>> with tf.GradientTape() as tape:
  ...   loss = loss_fn()
  >>> kernel_grad, bias_grad = tape.gradient(loss, (dense.kernel, dense.bias))
  >>> np.testing.assert_allclose(
  ...     kernel_grad, tf.stack(kernel_fprop)[:, tf.newaxis])
  >>> np.testing.assert_allclose(bias_grad, bias_fprop[tf.newaxis])

  Implicit in the `tape.gradient` call is a length-one vector which
  left-multiplies the Jacobian, a vector-Jacobian product.

  `ForwardAccumulator` maintains JVPs corresponding to the primal tensors it is
  watching, derived from the original `primals` specified in the constructor.
  As soon as a primal tensor is deleted, `ForwardAccumulator` deletes the
  corresponding JVP.

  `acc.jvp(x)` retrieves `acc`'s JVP corresponding to the primal tensor `x`. It
  does not perform any computation. `acc.jvp` calls can be repeated as long as
  `acc` is accessible, whether the context manager is active or not. New JVPs
  are only computed while the context manager is active.

  Note that `ForwardAccumulator`s are always applied in the order their context
  managers were entered, so inner accumulators will not see JVP computation
  from outer accumulators. Take higher-order JVPs from outer accumulators:

  >>> primal = tf.constant(1.1)
  >>> with tf.autodiff.ForwardAccumulator(primal, tf.constant(1.)) as outer:
  ...   with tf.autodiff.ForwardAccumulator(primal, tf.constant(1.)) as inner:
  ...     primal_out = primal ** tf.constant(3.5)
  >>> inner_jvp = inner.jvp(primal_out)
  >>> inner_jvp  # 3.5 * 1.1 ** 2.5
  <tf.Tensor: shape=(), dtype=float32, numpy=4.4417057>
  >>> outer.jvp(inner_jvp)  # 3.5 * 2.5 * 1.1 ** 1.5
  <tf.Tensor: shape=(), dtype=float32, numpy=10.094786>

  Reversing the collection in the last line to instead retrieve
  `inner.jvp(outer.jvp(primal_out))` will not work.

  Strict nesting also applies to combinations of `ForwardAccumulator` and
  `tf.GradientTape`. More deeply nested `GradientTape` objects will ignore the
  products of outer `ForwardAccumulator` objects. This allows (for example)
  memory-efficient forward-over-backward computation of Hessian-vector
  products, where the inner `GradientTape` would otherwise hold on to all
  intermediate JVPs:

  >>> v = tf.Variable([1., 2.])
  >>> with tf.autodiff.ForwardAccumulator(
  ...     v,
  ...     # The "vector" in Hessian-vector product.
  ...     tf.constant([1., 0.])) as acc:
  ...   with tf.GradientTape() as tape:
  ...     y = tf.reduce_sum(v ** 3.)
  ...   backward = tape.gradient(y, v)
  >>> backward  # gradient from backprop
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([ 3., 12.], dtype=float32)>
  >>> acc.jvp(backward)  # forward-over-backward Hessian-vector product
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([6., 0.], dtype=float32)>
  """

  def __init__(self, primals, tangents):
    """Specify tensors to watch and their Jacobian-vector products.

    Mathematically, `tangents` is a vector right-multiplying the Jacobian
    matrix (a Jacobian-vector product) for the function computed while this
    accumulator is active. Since JVPs are computed in forward mode as the
    computation happens, this vector must be supplied in advance.

    Listing a single tensor multiple times in `primals` raises an
    exception. Excluding a tensor from `primals` is equivalent to watching it
    with a tangent tensor of zeros.

    Args:
      primals: A tensor or nested structure of tensors to watch.
      tangents: A tensor or nested structure of tensors, with the same nesting
        structure as `primals`, with each element being a vector with the same
        size as the corresponding primal element.

    Raises:
      ValueError: If the same tensor or variable is specified multiple times in
        `primals`.
    """
    self._accumulator = pywrap_tfe.TFE_Py_ForwardAccumulatorNew(False)
    self._recording = False
    primal_ids = set()
    for primal in nest.flatten(primals):
      if id(primal) in primal_ids:
        raise ValueError(
            "Tensor {} was specified as a primal multiple times. This may "
            "indicate an error. If it was intended, please sum the "
            "corresponding tangents.".format(primal))
      primal_ids.add(id(primal))
    self._watch(primals, tangents)

  def __enter__(self):
    self._push_accumulator()
    return self

  def __exit__(self, typ, value, traceback):
    if self._recording:
      self._pop_accumulator()

  def _push_accumulator(self):
    if self._recording:
      raise ValueError("Accumulator is already recording.")
    pywrap_tfe.TFE_Py_ForwardAccumulatorSetAdd(self._accumulator)
    self._recording = True

  def _pop_accumulator(self):
    if not self._recording:
      raise ValueError("Accumulator is not recording.")
    pywrap_tfe.TFE_Py_ForwardAccumulatorSetRemove(self._accumulator)
    self._recording = False

  def _watch(self, primals, tangents):
    """Ensures that `primals` are being traced by this accumulator.

    Mathematically, `tangents` is a vector right-multiplying the Jacobian
    matrix (a Jacobian-vector product) for the function computed while this
    accumulator is active. Since JVPs are computed in forward mode as the
    computation happens, this vector must be supplied in advance.

    Watching a single tensor multiple times sums each of its `tangents`. Any
    un-watched tensor has zeros for its tangent vector.

    Args:
      primals: A Tensor or list of Tensors.
      tangents: A Tensor or list of Tensors matching `primals`.
    """

    def _watch(primal, tangent):
      if not primal.dtype.is_floating:
        logging.log_first_n(
            logging.WARN, "The dtype of the watched primal must be "
            "floating (e.g. tf.float32), got %r", 5, primal.dtype)
      tangent = ops.convert_to_tensor(tangent, dtype=primal.dtype)
      if hasattr(primal, "handle"):
        # Run convert_to_tensor to get the captured handle from whichever
        # function we're running if necessary.
        primal = ops.convert_to_tensor(primal.handle)
      pywrap_tfe.TFE_Py_ForwardAccumulatorWatch(self._accumulator, primal,
                                                tangent)

    nest.map_structure(_watch, primals, tangents, expand_composites=True)

  def jvp(self, primals, unconnected_gradients=UnconnectedGradients.NONE):
    """Fetches the Jacobian-vector product computed for `primals`.

    Note that this method performs no computation, and simply looks up a JVP
    that was already computed (unlike backprop using a `tf.GradientTape`, where
    the computation happens on the call to `tape.gradient`).
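
    For example, a small sketch of the connected and unconnected cases:

    >>> primal = tf.constant(1.)
    >>> with tf.autodiff.ForwardAccumulator(primal, tf.constant(2.)) as acc:
    ...   connected = primal * primal
    ...   unconnected = tf.constant(3.)  # does not depend on `primal`
    >>> acc.jvp(connected)  # 2. * primal * tangent = 2. * 1. * 2.
    <tf.Tensor: shape=(), dtype=float32, numpy=4.0>
    >>> print(acc.jvp(unconnected))  # no JVP was computed for this tensor
    None
    >>> acc.jvp(unconnected, unconnected_gradients='zero')
    <tf.Tensor: shape=(), dtype=float32, numpy=0.0>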

    Args:
      primals: A watched Tensor or structure of Tensors to fetch the JVPs for.
      unconnected_gradients: A value which can either hold 'none' or 'zero' and
        alters the value which will be returned if no JVP was computed for
        `primals`. The possible values and effects are detailed in
        'tf.UnconnectedGradients' and it defaults to 'none'.

    Returns:
      Tensors with the same shapes and dtypes as `primals`, or None if no JVP
      is available.
    """
    unconnected_gradients = UnconnectedGradients(unconnected_gradients)
    if self._accumulator is None:
      raise ValueError("Called jvp() without first tracing anything.")

    def _fetch_jvp(tensor):
      if hasattr(tensor, "handle"):
        unwrapped_tensor = ops.convert_to_tensor(tensor.handle)
      else:
        unwrapped_tensor = tensor
      result = pywrap_tfe.TFE_Py_ForwardAccumulatorJVP(self._accumulator,
                                                       unwrapped_tensor)
      if result is None and unconnected_gradients == UnconnectedGradients.ZERO:
        result = array_ops.zeros_like(tensor)
      return result

    return nest.map_structure(_fetch_jvp, primals)

  @classmethod
  def _batch_accumulator(cls, primals, tangents):
    """Factory constructor to test accumulator on batches of tangents.

    Args:
      primals: A tensor or nested structure of tensors to watch.
      tangents: A tensor or nested structure of tensors, with the same nesting
        structure as `primals`, with each element being a vector with a shape
        `[None] + primal.shape` compatible with the corresponding primal
        element.

    Returns:
      A batch accumulator object.
    """
    acc = super(ForwardAccumulator, cls).__new__(cls, primals, tangents)
    acc._recording = False
    acc._accumulator = pywrap_tfe.TFE_Py_ForwardAccumulatorNew(True)
    primal_ids = set()
    for primal, tangent in zip(nest.flatten(primals), nest.flatten(tangents)):
      tangent.shape.assert_is_compatible_with(
          tensor_shape.TensorShape([None]) + primal.shape)
      if id(primal) in primal_ids:
        raise ValueError(
            "Tensor {} was specified as a primal multiple times. This may "
            "indicate an error. If it was intended, please sum the "
            "corresponding tangents.".format(primal))
      primal_ids.add(id(primal))
    acc._watch(primals, tangents)
    return acc
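

# A rough reference for what the test-only batch accumulator above computes
# (an illustrative sketch, not part of the public API; `f`, `primal`, and
# `batched_tangents` are placeholders): with a batch of tangents stacked along
# a leading axis, the batched JVP is expected to match running one ordinary
# ForwardAccumulator pass per tangent and stacking the results.
#
#   import tensorflow as tf
#
#   def batch_jvp_reference(f, primal, batched_tangents):
#     jvps = []
#     for tangent in tf.unstack(batched_tangents):
#       with tf.autodiff.ForwardAccumulator(primal, tangent) as acc:
#         y = f(primal)
#       jvps.append(acc.jvp(y))
#     return tf.stack(jvps)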