1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Code for backpropagation using the tape utilities.""" 16 17# TODO(b/159343581): Properly support CompositeTensor in all functions in this 18# file. 19 20from __future__ import absolute_import 21from __future__ import division 22from __future__ import print_function 23 24import functools 25import operator 26import sys 27 28import six 29 30from tensorflow.python import pywrap_tfe 31from tensorflow.python.eager import backprop_util 32from tensorflow.python.eager import context 33from tensorflow.python.eager import execute 34from tensorflow.python.eager import imperative_grad 35from tensorflow.python.eager import tape 36from tensorflow.python.framework import constant_op 37from tensorflow.python.framework import dtypes 38from tensorflow.python.framework import ops 39from tensorflow.python.framework import tensor_shape 40from tensorflow.python.framework import tensor_util 41from tensorflow.python.ops import array_ops 42from tensorflow.python.ops import check_ops 43from tensorflow.python.ops import control_flow_util 44from tensorflow.python.ops import default_gradient 45from tensorflow.python.ops import gen_array_ops 46from tensorflow.python.ops import gen_math_ops 47from tensorflow.python.ops import math_ops 48from tensorflow.python.ops import resource_variable_ops 49from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients 50from tensorflow.python.platform import tf_logging as logging 51from tensorflow.python.util import _pywrap_utils 52from tensorflow.python.util import nest 53from tensorflow.python.util import tf_contextlib 54from tensorflow.python.util import tf_inspect 55from tensorflow.python.util.lazy_loader import LazyLoader 56from tensorflow.python.util.tf_export import tf_export 57 58 59# Note that we need to lazy load the following two modules to avoid creating 60# circular dependencies. 61# TODO(b/119775953): fix the circular dependencies. 62pfor_ops = LazyLoader( 63 "pfor_ops", globals(), 64 "tensorflow.python.ops.parallel_for.control_flow_ops") 65 66function = LazyLoader("function", globals(), 67 "tensorflow.python.eager.function") 68 69_op_attr_type_cache = {} 70 71 72def op_attr_type(op_type, attr_name): 73 try: 74 return _op_attr_type_cache[(op_type, attr_name)] 75 except KeyError: 76 context.ensure_initialized() 77 h = context.context()._handle # pylint: disable=protected-access 78 attr_type = pywrap_tfe.TFE_OpNameGetAttrType(h, op_type, attr_name) 79 _op_attr_type_cache[(op_type, attr_name)] = attr_type 80 return attr_type 81 82 83def make_attr(attr_type, value): 84 # pybind11 enums do not return the raw value like SWIG enums do. They are 85 # useful when comparing amongst each other but not direct integers as we are 86 # doing in most tests. 87 # https://pybind11.readthedocs.io/en/stable/classes.html#enumerations-and-internal-types 88 # TODO(amitpatankar): After all SWIG transitions, convert the enum comparisons 89 # from integer value to class. 90 if attr_type == int(pywrap_tfe.TF_ATTR_TYPE): 91 return dtypes.as_dtype(value) 92 if attr_type == [int(pywrap_tfe.TF_ATTR_TYPE)]: 93 return [dtypes.as_dtype(v) for v in value] 94 if attr_type == int(pywrap_tfe.TF_ATTR_SHAPE): 95 return tensor_shape.as_shape(value).as_proto() 96 if attr_type == [int(pywrap_tfe.TF_ATTR_SHAPE)]: 97 return [tensor_shape.as_shape(v).as_proto() for v in value] 98 if isinstance(value, str): 99 return value.encode() 100 return value 101 102 103class _MockOp(object): 104 """Pretends to be a tf.Operation for the gradient functions.""" 105 106 def __init__(self, attrs, inputs, outputs, typ, skip_input_indices): 107 self.attrs = attrs 108 self.inputs = inputs 109 self.outputs = outputs 110 self.type = typ 111 self.skip_input_indices = skip_input_indices 112 113 def get_attr(self, attr): 114 typ = op_attr_type(self.type, attr) 115 for i in range(0, len(self.attrs), 2): 116 if self.attrs[i] == attr: 117 return make_attr(typ, self.attrs[i + 1]) 118 raise KeyError(attr) 119 120 def _get_control_flow_context(self): 121 raise NotImplementedError( 122 "tf.GradientTape.gradients() does not support graph control flow " 123 "operations like tf.cond or tf.while at this time. Use tf.gradients() " 124 "instead. If you need this feature, please file a feature request at " 125 "https://github.com/tensorflow/tensorflow/issues/new" 126 ) 127 128 129def _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs, 130 out_grads, skip_input_indices, forward_pass_name_scope): 131 """Calls the gradient function of the op. 132 133 Args: 134 op_name: the name of the op to be differentiated. 135 attr_tuple: the attrs, as a tuple. 136 num_inputs: the number of inputs to the op. 137 inputs: inputs to the original operation. 138 outputs: outputs to the original operation. 139 out_grads: gradients of the operation wrt its outputs. 140 skip_input_indices: a tuple that is passed to the gradient function, 141 indicating which inputs to skip calculating the gradient for 142 forward_pass_name_scope: the namescope of the op in the forward pass. 143 144 Returns: 145 The gradients with respect to the inputs of the function, as a list. 146 """ 147 mock_op = _MockOp(attr_tuple, inputs, outputs, op_name, skip_input_indices) 148 grad_fn = ops._gradient_registry.lookup(op_name) # pylint: disable=protected-access 149 if grad_fn is None: 150 return [None] * num_inputs 151 152 # This does not work with v1 TensorArrays. 153 if ops.executing_eagerly_outside_functions( 154 ) or control_flow_util.EnableControlFlowV2(ops.get_default_graph()): 155 gradient_name_scope = "gradient_tape/" 156 if forward_pass_name_scope: 157 gradient_name_scope += forward_pass_name_scope + "/" 158 with ops.name_scope(gradient_name_scope): 159 return grad_fn(mock_op, *out_grads) 160 else: 161 return grad_fn(mock_op, *out_grads) 162 163 164pywrap_tfe.TFE_Py_RegisterGradientFunction(_gradient_function) 165 166 167def _must_record_gradient(): 168 return not pywrap_tfe.TFE_Py_TapeSetIsEmpty() 169 170 171@tf_export("__internal__.record_gradient", v1=[]) 172def record_gradient(op_name, inputs, attrs, outputs): 173 """Explicitly record the gradient for a given op. 174 175 Args: 176 op_name: The op name as listed in the `OpDef` for the op. 177 inputs: A list of tensor inputs to the op. 178 attrs: The op attributes as a flattened list of alternating attribute names 179 and attribute values. 180 outputs: A list of tensor outputs from the op. 181 """ 182 pywrap_tfe.TFE_Py_RecordGradient(op_name, inputs, attrs, outputs, 183 ops.get_name_scope()) 184 185 186execute.must_record_gradient = _must_record_gradient 187execute.record_gradient = record_gradient 188 189 190def implicit_val_and_grad(f): 191 """Returns a function which differentiates f with respect to variables. 192 193 The wrapped function returns the value and the gradient of f when called with 194 the same arguments. The gradient is with respect to all trainable TFE 195 variables accessed by `f`. 196 197 This function is useful when the exact set of variables to differentiate with 198 is not known ahead of time. 199 200 Example: 201 202 ```python 203 dense_layer = tf.compat.v1.layers.Dense(1) 204 def loss(x, y): 205 return tf.reduce_sum(tf.square(dense_layer(x) - y)) 206 207 # Obtain the gradient function. 208 val_grad_fn = tfe.implicit_value_and_gradients(loss) 209 210 # Invoke the gradient function with concrete values of x and y. 211 x = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 212 y = tf.constant([[10.0], [20.0]]) 213 value, grads_and_vars = val_grad_fn(x, y) 214 print('Value of loss: %s' % value) 215 216 # Apply the gradients to Variables. 217 optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1) 218 optimizer.apply_gradients(grads_and_vars) 219 ``` 220 221 Args: 222 f: function to be differentiated. If `f` returns a scalar, this scalar will 223 be differentiated. If `f` returns a tensor or list of tensors, by default 224 a scalar will be computed by adding all their values to produce a single 225 scalar. 226 227 Returns: 228 A function which, when called, returns a tuple pair. 229 Its first element is the value to which the function evaluates. 230 Its second element is list of (gradient, variable) pairs. 231 232 Raises: 233 ValueError: if `f` returns None. 234 """ 235 # TODO(cais): Remove calls to tf.constant() once the gradients functions 236 # accept lists and np.ndarrays. 237 238 def grad_fn(*args, **kwds): 239 """Computes the gradient of the wrapped function.""" 240 this_tape = tape.push_new_tape() 241 try: 242 end_node = f(*args, **kwds) 243 if end_node is None: 244 raise ValueError("Cannot differentiate a function that returns None; " 245 "did you forget to return a value from {}?".format( 246 f.__name__)) 247 finally: 248 tape.pop_tape(this_tape) 249 # Note: variables are returned in construction order. This ensures unique 250 # order across executions. 251 variables = this_tape.watched_variables() 252 if not variables: 253 raise ValueError("No trainable variables were accessed while the " 254 "function was being computed.") 255 256 sources = [v.handle for v in variables] 257 for s in sources: 258 if getattr(s, "is_packed", False): 259 raise ValueError( 260 "GradientTape.gradient is not supported on packed EagerTensors yet." 261 ) 262 grad = imperative_grad.imperative_grad(this_tape, nest.flatten(end_node), 263 sources) 264 return end_node, list(zip(grad, variables)) 265 266 return grad_fn 267 268 269def implicit_grad(f): 270 """Returns a function which differentiates f with respect to variables. 271 272 The wrapped function returns the gradient of f when called with the same 273 arguments. The gradient is with respect to all trainable TFE variables 274 accessed by `f`. 275 276 This function is useful when the exact set of variables to differentiate with 277 is not known ahead of time. 278 279 Example: 280 281 ```python 282 dense_layer = tf.compat.v1.layers.Dense(1) 283 def loss(x, y): 284 return tf.reduce_sum(tf.square(dense_layer(x) - y)) 285 286 # Obtain the gradient function. 287 grad_fn = tfe.implicit_gradients(loss) 288 289 # Invoke the gradient function with concrete values of x and y. 290 x = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 291 y = tf.constant([[10.0], [20.0]]) 292 grads_and_vars = grad_fn(x, y) 293 294 # Apply the gradients to Variables. 295 optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1) 296 optimizer.apply_gradients(grads_and_vars) 297 ``` 298 299 Args: 300 f: function to be differentiated. If `f` returns a scalar, this scalar will 301 be differentiated. If `f` returns a tensor or list of tensors, by default 302 a scalar will be computed by adding all their values to produce a single 303 scalar. 304 305 Returns: 306 A function which, when called, returns a list of (gradient, variable) pairs. 307 """ 308 # TODO(cais): Remove calls to tf.constant() once the gradients functions 309 # accept lists and np.ndarrays. 310 311 def grad_fn(*args, **kwds): 312 """Computes the gradient of the wrapped function.""" 313 return implicit_val_and_grad(f)(*args, **kwds)[1] 314 315 return grad_fn 316 317 318def _get_arg_spec(f, params, param_args): 319 """The positions of the parameters of f to be differentiated in param_args.""" 320 try: 321 args = tf_inspect.getfullargspec(f).args 322 except TypeError as e: 323 # TypeError can happen when f is a callable object. 324 if params is None: 325 return range(len(param_args)) 326 elif all(isinstance(x, int) for x in params): 327 return params 328 raise ValueError("Either callable provided is not a function or could not " 329 "inspect its arguments by name: %s. Original error: %s" 330 % (f, e)) 331 if params is None: 332 if not args: 333 return range(len(param_args)) 334 if args[0] == "self": 335 return range(len(args) - 1) 336 else: 337 return range(len(args)) 338 elif all(isinstance(x, six.string_types) for x in params): 339 return [args.index(n) for n in params] 340 elif all(isinstance(x, int) for x in params): 341 return params 342 else: 343 raise ValueError( 344 "params must be all strings or all integers; got %s." % params) 345 346 347def gradients_function(f, params=None): 348 """Returns a function which differentiates f with respect to params. 349 350 Example: 351 ```python 352 # f(x, y) = (x ^ 3) * y - x * (y ^ 2) 353 # Therefore, the 1st order derivatives are: 354 # df / dx = 3 * (x ^ 2) * y - y ^ 2 355 # df / dy = x ^ 3 - 2 * x * y 356 # The 2nd order derivatives with respect to x is: 357 # d^2 f / (dx)^2 = 6 * x * y 358 def f(x, y): 359 return x * x * x * y - x * y * y 360 361 # Obtain a function that returns 1st order gradients. 362 grad_fn = tfe.gradients_function(f) 363 364 x = 2.0 365 y = 3.0 366 367 # Invoke the 1st order gradient function. 368 x_grad, y_grad = grad_fn(x, y) 369 assert x_grad.numpy() == 3 * (2 ** 2) * 3 - 3 ** 2 370 assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3 371 372 # Obtain a function that returns the 2nd order gradient with respect to x. 373 gradgrad_fn = tfe.gradients_function(lambda x, y: grad_fn(x, y)[0]) 374 375 # Invoke the 2nd order gradient function. 376 x_gradgrad = gradgrad_fn(x, y)[0] 377 assert x_gradgrad.numpy() == 6 * 2 * 3 378 379 # To obtain a callable that returns the gradient(s) of `f` with respect to a 380 # subset of its inputs, use the `params` keyword argument with 381 # `gradients_function()`. 382 ygrad_fn = tfe.gradients_function(f, params=[1]) 383 384 (y_grad,) = ygrad_fn(x, y) 385 assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3 386 ``` 387 388 Note that only tensors with real or complex dtypes are differentiable. 389 390 Args: 391 f: function to be differentiated. If `f` returns a scalar, this scalar will 392 be differentiated. If `f` returns a tensor or list of tensors, by default 393 a scalar will be computed by adding all their values to produce a single 394 scalar. If desired, the tensors can be elementwise multiplied by the 395 tensors passed as the `dy` keyword argument to the returned gradient 396 function. 397 params: list of parameter names of f or list of integers indexing the 398 parameters with respect to which we'll differentiate. Passing None 399 differentiates with respect to all parameters. 400 401 Returns: 402 function which, when called, returns the value of f and the gradient 403 of `f` with respect to all of `params`. The function takes an extra optional 404 keyword argument `dy`. Setting it allows computation of vector jacobian 405 products for vectors other than the vector of ones. 406 407 Raises: 408 ValueError: if the params are not all strings or all integers. 409 """ 410 411 def decorated(*args, **kwds): 412 """Computes the gradient of the decorated function.""" 413 414 _, grad = val_and_grad_function(f, params=params)(*args, **kwds) 415 return grad 416 417 return decorated 418 419 420def _ensure_unique_tensor_objects(parameter_positions, args): 421 """Make each of the parameter_positions in args a unique ops.Tensor object. 422 423 Ensure that each parameter is treated independently. 424 For example: 425 426 def f(x, y): return x * y 427 g = gradients_function(f) 428 one = tf.constant(1.) 429 430 g(one, one) should return [1., 1.] 431 (even though the two arguments are the same Tensor object). 432 433 Args: 434 parameter_positions: List of indices into args defining the arguments to 435 differentiate against. 436 args: A list of arguments to the function to be differentiated. 437 438 Returns: 439 args, possibly edited in-place. 440 """ 441 s = set() 442 for (i, t) in enumerate(args): 443 if i in parameter_positions: 444 tid = ops.tensor_id(t) 445 if tid in s: 446 args[i] = gen_array_ops.identity(args[i]) 447 else: 448 s.add(tid) 449 return args 450 451 452def val_and_grad_function(f, params=None): 453 """Returns a function that computes f and its derivative w.r.t. params. 454 455 Example: 456 ```python 457 # f(x, y) = (x ^ 3) * y - x * (y ^ 2) 458 # Therefore, the 1st order derivatives are: 459 # df / dx = 3 * (x ^ 2) * y - y ^ 2 460 # df / dy = x ^ 3 - 2 * x * y 461 def f(x, y): 462 return x * x * x * y - x * y * y 463 464 # Obtain a function that returns the function value and the 1st order 465 # gradients. 466 val_grads_fn = tfe.value_and_gradients_function(f) 467 468 x = 2.0 469 y = 3.0 470 471 # Invoke the value-and-gradients function. 472 f_val, (x_grad, y_grad) = val_grads_fn(x, y) 473 assert f_val.numpy() == (2 ** 3) * 3 - 2 * (3 ** 2) 474 assert x_grad.numpy() == 3 * (2 ** 2) * 3 - 3 ** 2 475 assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3 476 477 # To obtain a callable that returns the value of `f` and the gradient(s) of 478 # `f` with respect to a subset of its inputs, use the `params` keyword 479 # argument with `value_and_gradients_function()`. 480 val_ygrad_fn = tfe.value_and_gradients_function(f, params=[1]) 481 482 f_val, (y_grad,) = val_ygrad_fn(x, y) 483 assert f_val.numpy() == (2 ** 3) * 3 - 2 * (3 ** 2) 484 assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3 485 ``` 486 487 Args: 488 f: function to be differentiated. If `f` returns a scalar, this scalar will 489 be differentiated. If `f` returns a tensor or list of tensors, by default 490 a scalar will be computed by adding all their values to produce a single 491 scalar. If desired, the tensors can be elementwise multiplied by the 492 tensors passed as the `dy` keyword argument to the returned gradient 493 function. 494 params: list of parameter names of f or list of integers indexing the 495 parameters with respect to which we'll differentiate. Passing `None` 496 differentiates with respect to all parameters. 497 498 Returns: 499 function which, when called, returns the value of f and the gradient 500 of f with respect to all of `params`. The function takes an extra optional 501 keyword argument "dy". Setting it allows computation of vector jacobian 502 products for vectors other than the vector of ones. 503 504 Raises: 505 ValueError: if the params are not all strings or all integers. 506 """ 507 508 def decorated(*args, **kwds): 509 """Computes the value and gradient of the decorated function.""" 510 dy = kwds.pop("dy", None) 511 if kwds: 512 raise ValueError("Functions to be differentiated cannot " 513 "receive keyword arguments.") 514 val, vjp = make_vjp(f, params)(*args, **kwds) 515 return val, vjp(dy=dy) 516 517 return decorated 518 519 520def make_vjp(f, params=None, persistent=True): 521 """Returns a function that computes f and its vjp w.r.t. 522 523 params. 524 525 The term "vjp" here is an abbreviation for vector-jacobian product. 526 527 Args: 528 f: the function to be differentiated. 529 params: the parameters (numbers or names) to differentiate with respect to. 530 A value of None will differentiate with respect to all parameters. 531 persistent: Boolean controlling whether the VJP function can be re-used. 532 Must be True or False. 533 534 Returns: 535 A function, which when called, returns a tuple (value, vjp), where: 536 - value is the result of calling f. 537 - vjp is a function, which takes a vector as an argument and 538 returns the product of that vector with the Jacobian of f. 539 Providing no argument to vjp is equivalent to providing a 540 vector of ones. 541 542 For example, 543 ```python 544 def f(x): 545 return x * x 546 547 wrapped_fn = tfe.make_vjp(f) 548 result, vjp = wrapped_fn(tf.constant(3.0)) 549 # result is 9.0 550 vjp() # the vjp function returns 6.0 551 552 Raises: 553 ValueError: if `f` returns None. 554 """ 555 556 def decorated(*args, **kwds): 557 """Computes the value and gradient of the decorated function.""" 558 parameter_positions = _get_arg_spec(f, params, args) 559 assert not kwds, "The gradient function can't take keyword arguments." 560 this_tape = tape.push_new_tape(persistent=persistent) 561 try: 562 sources = [] 563 args = [ 564 ops.convert_to_tensor(arg) if i in parameter_positions else arg 565 for i, arg in enumerate(args) 566 ] 567 args = _ensure_unique_tensor_objects(parameter_positions, args) 568 for i in parameter_positions: 569 if getattr(args[i], "is_packed", False): 570 raise ValueError( 571 "GradientTape.gradient is not supported on packed EagerTensors" 572 "yet.") 573 sources.append(args[i]) 574 tape.watch(this_tape, args[i]) 575 result = f(*args) 576 if result is None: 577 raise ValueError("Cannot differentiate a function that returns None; " 578 "did you forget to return a value from {}?".format( 579 f.__name__)) 580 flat_result = nest.flatten(result) 581 flat_result = [gen_array_ops.identity(x) for x in flat_result] 582 result = nest.pack_sequence_as(result, flat_result) 583 finally: 584 tape.pop_tape(this_tape) 585 def vjp(dy=None): 586 if dy is not None: 587 dy = [ops.convert_to_tensor(x) for x in nest.flatten(dy)] 588 return imperative_grad.imperative_grad( 589 this_tape, nest.flatten(result), sources, output_gradients=dy) 590 591 return result, vjp 592 593 return decorated 594 595 596def flatten_nested_indexed_slices(grad): 597 assert isinstance(grad, ops.IndexedSlices) 598 if isinstance(grad.values, ops.Tensor): 599 return grad 600 else: 601 assert isinstance(grad.values, ops.IndexedSlices) 602 g = flatten_nested_indexed_slices(grad.values) 603 return ops.IndexedSlices(g.values, array_ops.gather(grad.indices, 604 g.indices), 605 g.dense_shape) 606 607 608def aggregate_indexed_slices_gradients(grads): 609 """Aggregates gradients containing `IndexedSlices`s.""" 610 if len(grads) < 1: 611 return None 612 if len(grads) == 1: 613 return grads[0] 614 grads = [g for g in grads if g is not None] 615 # If any gradient is a `Tensor`, sum them up and return a dense tensor 616 # object. 617 if any(isinstance(g, ops.Tensor) for g in grads): 618 return math_ops.add_n(grads) 619 620 # The following `_as_indexed_slices_list` casts ids of IndexedSlices into 621 # int64. It is to make sure the inputs of `concat` all have same the data 622 # type. 623 grads = math_ops._as_indexed_slices_list(grads) # pylint: disable=protected-access 624 625 grads = [flatten_nested_indexed_slices(x) for x in grads] 626 # Form IndexedSlices out of the concatenated values and indices. 627 concat_grad = ops.IndexedSlices( 628 array_ops.concat([x.values for x in grads], axis=0), 629 array_ops.concat([x.indices for x in grads], axis=0), 630 grads[0].dense_shape) 631 632 return concat_grad 633 634 635def _aggregate_grads(gradients): 636 """Aggregate gradients from multiple sources. 637 638 Args: 639 gradients: A list of 'Tensor' or 'IndexedSlices' gradients. 640 641 Returns: 642 If 'gradients' only has 'Tensor', returns an aggregated 'Tensor'. 643 Otherwise returns an aggregated 'IndexedSlices'. 644 """ 645 assert gradients, "No gradients to aggregate" 646 647 if len(gradients) == 1: 648 return gradients[0] 649 if all(isinstance(g, ops.Tensor) for g in gradients): 650 return gen_math_ops.add_n(gradients) 651 else: 652 assert all(isinstance(g, (ops.Tensor, ops.IndexedSlices)) 653 for g in gradients) 654 return aggregate_indexed_slices_gradients(gradients) 655 656 657def _num_elements(grad): 658 """The number of elements in the `grad` tensor.""" 659 if isinstance(grad, ops.Tensor): 660 shape_tuple = grad._shape_tuple() # pylint: disable=protected-access 661 elif isinstance(grad, ops.IndexedSlices): 662 shape_tuple = grad.values._shape_tuple() # pylint: disable=protected-access 663 else: 664 raise ValueError("`grad` not a Tensor or IndexedSlices.") 665 if shape_tuple is None or None in shape_tuple: 666 return 0 667 return functools.reduce(operator.mul, shape_tuple, 1) 668 669 670def _fast_fill(value, shape, dtype): 671 return array_ops.fill( 672 constant_op.constant(shape, dtype=dtypes.int32), 673 constant_op.constant(value, dtype=dtype)) 674 675 676def _zeros(shape, dtype): 677 """Helper to return (possibly cached) zero tensors in eager mode.""" 678 # Note: variants will use _zeros_like 679 if dtype == dtypes.string or dtype == dtypes.resource: 680 return None 681 682 ctx = context.context() 683 if not ctx.executing_eagerly(): 684 return array_ops.zeros(shape, dtype) 685 686 device = ctx.device_name 687 688 if tensor_util.is_tf_type(shape): 689 shape_key = shape.ref() 690 else: 691 shape_key = shape 692 cache_key = shape_key, dtype, device 693 cached = ctx.zeros_cache().get(cache_key) 694 if cached is None: 695 if dtypes.as_dtype(dtype).is_bool: 696 value = False 697 else: 698 value = 0 699 cached = _fast_fill(value, shape, dtype) 700 ctx.zeros_cache().put(cache_key, cached) 701 return cached 702 703 704def _ones(shape, dtype): 705 as_dtype = dtypes.as_dtype(dtype) 706 if as_dtype == dtypes.string: 707 return None 708 709 if not context.executing_eagerly(): 710 return array_ops.ones(shape, dtype) 711 712 if as_dtype.is_bool: 713 value = True 714 else: 715 value = 1 716 717 if shape == (): # pylint: disable=g-explicit-bool-comparison 718 return constant_op.constant(value, dtype=dtype) 719 return _fast_fill(value, shape, dtype) 720 721 722_default_vspace = imperative_grad.VSpace( 723 num_elements_fn=_num_elements, 724 aggregate_fn=_aggregate_grads, 725 zeros_fn=_zeros, 726 ones_fn=_ones, 727 zeros_like_fn=default_gradient.zeros_like, 728 ones_like_fn=default_gradient.ones_like, 729 graph_shape_fn=gen_array_ops.shape) 730pywrap_tfe.TFE_Py_RegisterVSpace(_default_vspace) 731 732 733def _handle_or_self(x): 734 """Unwrap resource variable/ndarray to return tensors.""" 735 if resource_variable_ops.is_resource_variable(x): 736 return x.handle 737 return x 738 739 740@tf_export("GradientTape", "autodiff.GradientTape", v1=["GradientTape"]) 741class GradientTape(object): 742 """Record operations for automatic differentiation. 743 744 Operations are recorded if they are executed within this context manager and 745 at least one of their inputs is being "watched". 746 747 Trainable variables (created by `tf.Variable` or `tf.compat.v1.get_variable`, 748 where `trainable=True` is default in both cases) are automatically watched. 749 Tensors can be manually watched by invoking the `watch` method on this context 750 manager. 751 752 For example, consider the function `y = x * x`. The gradient at `x = 3.0` can 753 be computed as: 754 755 >>> x = tf.constant(3.0) 756 >>> with tf.GradientTape() as g: 757 ... g.watch(x) 758 ... y = x * x 759 >>> dy_dx = g.gradient(y, x) 760 >>> print(dy_dx) 761 tf.Tensor(6.0, shape=(), dtype=float32) 762 763 GradientTapes can be nested to compute higher-order derivatives. For example, 764 765 >>> x = tf.constant(5.0) 766 >>> with tf.GradientTape() as g: 767 ... g.watch(x) 768 ... with tf.GradientTape() as gg: 769 ... gg.watch(x) 770 ... y = x * x 771 ... dy_dx = gg.gradient(y, x) # dy_dx = 2 * x 772 >>> d2y_dx2 = g.gradient(dy_dx, x) # d2y_dx2 = 2 773 >>> print(dy_dx) 774 tf.Tensor(10.0, shape=(), dtype=float32) 775 >>> print(d2y_dx2) 776 tf.Tensor(2.0, shape=(), dtype=float32) 777 778 By default, the resources held by a GradientTape are released as soon as 779 GradientTape.gradient() method is called. To compute multiple gradients over 780 the same computation, create a persistent gradient tape. This allows multiple 781 calls to the gradient() method as resources are released when the tape object 782 is garbage collected. For example: 783 784 >>> x = tf.constant(3.0) 785 >>> with tf.GradientTape(persistent=True) as g: 786 ... g.watch(x) 787 ... y = x * x 788 ... z = y * y 789 >>> dz_dx = g.gradient(z, x) # (4*x^3 at x = 3) 790 >>> print(dz_dx) 791 tf.Tensor(108.0, shape=(), dtype=float32) 792 >>> dy_dx = g.gradient(y, x) 793 >>> print(dy_dx) 794 tf.Tensor(6.0, shape=(), dtype=float32) 795 796 By default GradientTape will automatically watch any trainable variables that 797 are accessed inside the context. If you want fine grained control over which 798 variables are watched you can disable automatic tracking by passing 799 `watch_accessed_variables=False` to the tape constructor: 800 801 >>> x = tf.Variable(2.0) 802 >>> w = tf.Variable(5.0) 803 >>> with tf.GradientTape( 804 ... watch_accessed_variables=False, persistent=True) as tape: 805 ... tape.watch(x) 806 ... y = x ** 2 # Gradients will be available for `x`. 807 ... z = w ** 3 # No gradients will be available as `w` isn't being watched. 808 >>> dy_dx = tape.gradient(y, x) 809 >>> print(dy_dx) 810 tf.Tensor(4.0, shape=(), dtype=float32) 811 >>> # No gradients will be available as `w` isn't being watched. 812 >>> dz_dy = tape.gradient(z, w) 813 >>> print(dz_dy) 814 None 815 816 Note that when using models you should ensure that your variables exist when 817 using `watch_accessed_variables=False`. Otherwise it's quite easy to make your 818 first iteration not have any gradients: 819 820 ```python 821 a = tf.keras.layers.Dense(32) 822 b = tf.keras.layers.Dense(32) 823 824 with tf.GradientTape(watch_accessed_variables=False) as tape: 825 tape.watch(a.variables) # Since `a.build` has not been called at this point 826 # `a.variables` will return an empty list and the 827 # tape will not be watching anything. 828 result = b(a(inputs)) 829 tape.gradient(result, a.variables) # The result of this computation will be 830 # a list of `None`s since a's variables 831 # are not being watched. 832 ``` 833 834 Note that only tensors with real or complex dtypes are differentiable. 835 """ 836 837 def __init__(self, persistent=False, watch_accessed_variables=True): 838 """Creates a new GradientTape. 839 840 Args: 841 persistent: Boolean controlling whether a persistent gradient tape 842 is created. False by default, which means at most one call can 843 be made to the gradient() method on this object. 844 watch_accessed_variables: Boolean controlling whether the tape will 845 automatically `watch` any (trainable) variables accessed while the tape 846 is active. Defaults to True meaning gradients can be requested from any 847 result computed in the tape derived from reading a trainable `Variable`. 848 If False users must explicitly `watch` any `Variable`s they want to 849 request gradients from. 850 """ 851 self._tape = None 852 self._persistent = persistent 853 self._watch_accessed_variables = watch_accessed_variables 854 self._watched_variables = () 855 self._recording = False 856 857 def __enter__(self): 858 """Enters a context inside which operations are recorded on this tape.""" 859 self._push_tape() 860 return self 861 862 def __exit__(self, typ, value, traceback): 863 """Exits the recording context, no further operations are traced.""" 864 if self._recording: 865 self._pop_tape() 866 867 def _push_tape(self): 868 """Pushes a new tape onto the tape stack.""" 869 if self._recording: 870 raise ValueError("Tape is still recording, This can happen if you try to " 871 "re-enter an already-active tape.") 872 if self._tape is None: 873 self._tape = tape.push_new_tape( 874 persistent=self._persistent, 875 watch_accessed_variables=self._watch_accessed_variables) 876 else: 877 tape.push_tape(self._tape) 878 self._recording = True 879 880 def _pop_tape(self): 881 if not self._recording: 882 raise ValueError("Tape is not recording.") 883 tape.pop_tape(self._tape) 884 self._recording = False 885 886 @tf_contextlib.contextmanager 887 def _ensure_recording(self): 888 """Ensures that this tape is recording.""" 889 if not self._recording: 890 try: 891 self._push_tape() 892 yield 893 finally: 894 self._pop_tape() 895 else: 896 yield 897 898 def watch(self, tensor): 899 """Ensures that `tensor` is being traced by this tape. 900 901 Args: 902 tensor: a Tensor or list of Tensors. 903 904 Raises: 905 ValueError: if it encounters something that is not a tensor. 906 """ 907 for t in nest.flatten(tensor, expand_composites=True): 908 if not (_pywrap_utils.IsTensor(t) or _pywrap_utils.IsVariable(t)): 909 raise ValueError("Passed in object of type {}, not tf.Tensor".format( 910 type(t))) 911 if not backprop_util.IsTrainable(t): 912 logging.log_first_n( 913 logging.WARN, "The dtype of the watched tensor must be " 914 "floating (e.g. tf.float32), got %r", 5, t.dtype) 915 if hasattr(t, "handle"): 916 # There are many variable-like objects, all of them currently have 917 # `handle` attribute that points to a tensor. If this changes, internals 918 # of watch_variable need to change as well. 919 tape.watch_variable(self._tape, t) 920 else: 921 tape.watch(self._tape, t) 922 923 @tf_contextlib.contextmanager 924 def stop_recording(self): 925 """Temporarily stops recording operations on this tape. 926 927 Operations executed while this context manager is active will not be 928 recorded on the tape. This is useful for reducing the memory used by tracing 929 all computations. 930 931 For example: 932 933 >>> x = tf.constant(4.0) 934 >>> with tf.GradientTape() as tape: 935 ... with tape.stop_recording(): 936 ... y = x ** 2 937 >>> dy_dx = tape.gradient(y, x) 938 >>> print(dy_dx) 939 None 940 941 Yields: 942 None 943 Raises: 944 RuntimeError: if the tape is not currently recording. 945 """ 946 if self._tape is None: 947 raise RuntimeError( 948 "Trying to stop recording a tape which is not recording.") 949 self._pop_tape() 950 try: 951 yield 952 finally: 953 self._push_tape() 954 955 def reset(self): 956 """Clears all information stored in this tape. 957 958 Equivalent to exiting and reentering the tape context manager with a new 959 tape. For example, the two following code blocks are equivalent: 960 961 ``` 962 with tf.GradientTape() as t: 963 loss = loss_fn() 964 with tf.GradientTape() as t: 965 loss += other_loss_fn() 966 t.gradient(loss, ...) # Only differentiates other_loss_fn, not loss_fn 967 968 969 # The following is equivalent to the above 970 with tf.GradientTape() as t: 971 loss = loss_fn() 972 t.reset() 973 loss += other_loss_fn() 974 t.gradient(loss, ...) # Only differentiates other_loss_fn, not loss_fn 975 ``` 976 977 This is useful if you don't want to exit the context manager for the tape, 978 or can't because the desired reset point is inside a control flow construct: 979 980 ``` 981 with tf.GradientTape() as t: 982 loss = ... 983 if loss > k: 984 t.reset() 985 ``` 986 """ 987 self._pop_tape() 988 self._tape = None 989 self._push_tape() 990 991 def watched_variables(self): 992 """Returns variables watched by this tape in order of construction.""" 993 if self._tape is not None: 994 self._watched_variables = self._tape.watched_variables() 995 return self._watched_variables 996 997 def gradient(self, 998 target, 999 sources, 1000 output_gradients=None, 1001 unconnected_gradients=UnconnectedGradients.NONE): 1002 """Computes the gradient using operations recorded in context of this tape. 1003 1004 Note: Unless you set `persistent=True` a GradientTape can only be used to 1005 compute one set of gradients (or jacobians). 1006 1007 Args: 1008 target: a list or nested structure of Tensors or Variables to be 1009 differentiated. 1010 sources: a list or nested structure of Tensors or Variables. `target` 1011 will be differentiated against elements in `sources`. 1012 output_gradients: a list of gradients, one for each element of 1013 target. Defaults to None. 1014 unconnected_gradients: a value which can either hold 'none' or 'zero' and 1015 alters the value which will be returned if the target and sources are 1016 unconnected. The possible values and effects are detailed in 1017 'UnconnectedGradients' and it defaults to 'none'. 1018 1019 Returns: 1020 a list or nested structure of Tensors (or IndexedSlices, or None), 1021 one for each element in `sources`. Returned structure is the same as 1022 the structure of `sources`. 1023 1024 Raises: 1025 RuntimeError: If called on a used, non-persistent tape. 1026 RuntimeError: If called inside the context of the tape. 1027 TypeError: If the target is a None object. 1028 ValueError: If the target is a variable or if unconnected gradients is 1029 called with an unknown value. 1030 """ 1031 if self._tape is None: 1032 raise RuntimeError("A non-persistent GradientTape can only be used to " 1033 "compute one set of gradients (or jacobians)") 1034 if self._recording: 1035 if not self._persistent: 1036 self._pop_tape() 1037 else: 1038 logging.log_first_n( 1039 logging.WARN, "Calling GradientTape.gradient on a persistent " 1040 "tape inside its context is significantly less " 1041 "efficient than calling it outside the context (it " 1042 "causes the gradient ops to be recorded on the " 1043 "tape, leading to increased CPU and memory usage). " 1044 "Only call GradientTape.gradient inside the " 1045 "context if you actually want to trace the " 1046 "gradient in order to compute higher order " 1047 "derivatives.", 1) 1048 1049 if target is None: 1050 raise TypeError("Target should be a list or nested structure" 1051 " of Tensors or Variables to be differentiated," 1052 " but recieved %r" % (target)) 1053 1054 flat_targets = [] 1055 for t in nest.flatten(target): 1056 if not backprop_util.IsTrainable(t): 1057 logging.vlog( 1058 logging.WARN, "The dtype of the target tensor must be " 1059 "floating (e.g. tf.float32) when calling GradientTape.gradient, " 1060 "got %r", t.dtype) 1061 if resource_variable_ops.is_resource_variable(t): 1062 with self: 1063 t = ops.convert_to_tensor(t) 1064 flat_targets.append(t) 1065 1066 flat_sources = nest.flatten(sources) 1067 flat_sources_raw = flat_sources 1068 flat_sources = [_handle_or_self(x) for x in flat_sources] 1069 for t in flat_sources_raw: 1070 if not backprop_util.IsTrainable(t): 1071 logging.vlog( 1072 logging.WARN, "The dtype of the source tensor must be " 1073 "floating (e.g. tf.float32) when calling GradientTape.gradient, " 1074 "got %r", t.dtype) 1075 if getattr(t, "is_packed", False): 1076 raise ValueError( 1077 "GradientTape.gradient is not supported on packed EagerTensors yet." 1078 ) 1079 1080 if output_gradients is not None: 1081 output_gradients = [None if x is None else ops.convert_to_tensor(x) 1082 for x in nest.flatten(output_gradients)] 1083 1084 flat_grad = imperative_grad.imperative_grad( 1085 self._tape, 1086 flat_targets, 1087 flat_sources, 1088 output_gradients=output_gradients, 1089 sources_raw=flat_sources_raw, 1090 unconnected_gradients=unconnected_gradients) 1091 1092 if not self._persistent: 1093 # Keep track of watched variables before setting tape to None 1094 self._watched_variables = self._tape.watched_variables() 1095 self._tape = None 1096 1097 grad = nest.pack_sequence_as(sources, flat_grad) 1098 return grad 1099 1100 def jacobian(self, 1101 target, 1102 sources, 1103 unconnected_gradients=UnconnectedGradients.NONE, 1104 parallel_iterations=None, 1105 experimental_use_pfor=True): 1106 """Computes the jacobian using operations recorded in context of this tape. 1107 1108 Note: Unless you set `persistent=True` a GradientTape can only be used to 1109 compute one set of gradients (or jacobians). 1110 1111 Note: By default the jacobian implementation uses parallel for (pfor), which 1112 creates a tf.function under the hood for each jacobian call. For better 1113 performance, and to avoid recompilation and vectorization rewrites on each 1114 call, enclose GradientTape code in @tf.function. 1115 1116 See[wikipedia 1117 article](http://en.wikipedia.org/wiki/jacobian_matrix_and_determinant) 1118 for the definition of a Jacobian. 1119 1120 Example usage: 1121 1122 ```python 1123 with tf.GradientTape() as g: 1124 x = tf.constant([1.0, 2.0]) 1125 g.watch(x) 1126 y = x * x 1127 jacobian = g.jacobian(y, x) 1128 # jacobian value is [[2., 0.], [0., 4.]] 1129 ``` 1130 1131 Args: 1132 target: Tensor to be differentiated. 1133 sources: a list or nested structure of Tensors or Variables. `target` 1134 will be differentiated against elements in `sources`. 1135 unconnected_gradients: a value which can either hold 'none' or 'zero' and 1136 alters the value which will be returned if the target and sources are 1137 unconnected. The possible values and effects are detailed in 1138 'UnconnectedGradients' and it defaults to 'none'. 1139 parallel_iterations: A knob to control how many iterations are dispatched 1140 in parallel. This knob can be used to control the total memory usage. 1141 experimental_use_pfor: If true, vectorizes the jacobian computation. Else 1142 falls back to a sequential while_loop. Vectorization can sometimes fail 1143 or lead to excessive memory usage. This option can be used to disable 1144 vectorization in such cases. 1145 1146 Returns: 1147 A list or nested structure of Tensors (or None), one for each element in 1148 `sources`. Returned structure is the same as the structure of `sources`. 1149 Note if any gradient is sparse (IndexedSlices), jacobian function 1150 currently makes it dense and returns a Tensor instead. This may change in 1151 the future. 1152 1153 1154 Raises: 1155 RuntimeError: If called on a used, non-persistent tape. 1156 RuntimeError: If called on a non-persistent tape with eager execution 1157 enabled and without enabling experimental_use_pfor. 1158 ValueError: If vectorization of jacobian computation fails. 1159 """ 1160 if self._tape is None: 1161 raise RuntimeError("A non-persistent GradientTape can only be used to " 1162 "compute one set of gradients (or jacobians)") 1163 1164 flat_sources = nest.flatten(sources) 1165 target_static_shape = target.shape 1166 target_shape = array_ops.shape(target) 1167 # Note that we push and pop the tape here and below. This is needed since we 1168 # need gradients through the enclosed operations. 1169 with self._ensure_recording(): 1170 target = array_ops.reshape(target, [-1]) 1171 1172 def loop_fn(i): 1173 with self._ensure_recording(): 1174 y = array_ops.gather(target, i) 1175 return self.gradient(y, flat_sources, 1176 unconnected_gradients=unconnected_gradients) 1177 1178 try: 1179 target_size = int(target.shape[0]) 1180 except TypeError: 1181 target_size = array_ops.shape(target)[0] 1182 1183 if experimental_use_pfor: 1184 try: 1185 output = pfor_ops.pfor(loop_fn, target_size, 1186 parallel_iterations=parallel_iterations) 1187 except ValueError as err: 1188 six.reraise( 1189 ValueError, 1190 ValueError( 1191 str(err) + "\nEncountered an exception while vectorizing the " 1192 "jacobian computation. Vectorization can be disabled by setting" 1193 " experimental_use_pfor to False."), 1194 sys.exc_info()[2]) 1195 else: 1196 if context.executing_eagerly() and not self._persistent: 1197 raise RuntimeError( 1198 "GradientTape must be created with persistent=True" 1199 " to compute the jacobian with eager execution enabled and with " 1200 " experimental_use_pfor set to False.") 1201 output = pfor_ops.for_loop( 1202 loop_fn, [target.dtype] * len(flat_sources), target_size, 1203 parallel_iterations=parallel_iterations) 1204 1205 for i, out in enumerate(output): 1206 if out is not None: 1207 new_shape = array_ops.concat( 1208 [target_shape, array_ops.shape(out)[1:]], axis=0) 1209 out = array_ops.reshape(out, new_shape) 1210 if context.executing_eagerly(): 1211 out.set_shape(target_static_shape.concatenate(flat_sources[i].shape)) 1212 output[i] = out 1213 1214 return nest.pack_sequence_as(sources, output) 1215 1216 def batch_jacobian(self, 1217 target, 1218 source, 1219 unconnected_gradients=UnconnectedGradients.NONE, 1220 parallel_iterations=None, 1221 experimental_use_pfor=True): 1222 """Computes and stacks per-example jacobians. 1223 1224 See [wikipedia article](http://en.wikipedia.org/wiki/jacobian_matrix_and_determinant) 1225 for the definition of a Jacobian. This function is essentially an efficient 1226 implementation of the following: 1227 1228 `tf.stack([self.jacobian(y[i], x[i]) for i in range(x.shape[0])])`. 1229 1230 Note that compared to `GradientTape.jacobian` which computes gradient of 1231 each output value w.r.t each input value, this function is useful when 1232 `target[i,...]` is independent of `source[j,...]` for `j != i`. This 1233 assumption allows more efficient computation as compared to 1234 `GradientTape.jacobian`. The output, as well as intermediate activations, 1235 are lower dimensional and avoid a bunch of redundant zeros which would 1236 result in the jacobian computation given the independence assumption. 1237 1238 Note: Unless you set `persistent=True` a GradientTape can only be used to 1239 compute one set of gradients (or jacobians). 1240 1241 Note: By default the batch_jacobian implementation uses parallel for (pfor), 1242 which creates a tf.function under the hood for each batch_jacobian call. 1243 For better performance, and to avoid recompilation and vectorization 1244 rewrites on each call, enclose GradientTape code in @tf.function. 1245 1246 1247 Example usage: 1248 1249 ```python 1250 with tf.GradientTape() as g: 1251 x = tf.constant([[1., 2.], [3., 4.]], dtype=tf.float32) 1252 g.watch(x) 1253 y = x * x 1254 batch_jacobian = g.batch_jacobian(y, x) 1255 # batch_jacobian is [[[2, 0], [0, 4]], [[6, 0], [0, 8]]] 1256 ``` 1257 1258 Args: 1259 target: A tensor with rank 2 or higher and with shape [b, y1, ..., y_n]. 1260 `target[i,...]` should only depend on `source[i,...]`. 1261 source: A tensor with rank 2 or higher and with shape [b, x1, ..., x_m]. 1262 unconnected_gradients: a value which can either hold 'none' or 'zero' and 1263 alters the value which will be returned if the target and sources are 1264 unconnected. The possible values and effects are detailed in 1265 'UnconnectedGradients' and it defaults to 'none'. 1266 parallel_iterations: A knob to control how many iterations are dispatched 1267 in parallel. This knob can be used to control the total memory usage. 1268 experimental_use_pfor: If true, uses pfor for computing the Jacobian. Else 1269 uses a tf.while_loop. 1270 1271 Returns: 1272 A tensor `t` with shape [b, y_1, ..., y_n, x1, ..., x_m] where `t[i, ...]` 1273 is the jacobian of `target[i, ...]` w.r.t. `source[i, ...]`, i.e. stacked 1274 per-example jacobians. 1275 1276 Raises: 1277 RuntimeError: If called on a used, non-persistent tape. 1278 RuntimeError: If called on a non-persistent tape with eager execution 1279 enabled and without enabling experimental_use_pfor. 1280 ValueError: If vectorization of jacobian computation fails or if first 1281 dimension of `target` and `source` do not match. 1282 """ 1283 if self._tape is None: 1284 raise RuntimeError("A non-persistent GradientTape can only be used to" 1285 "compute one set of gradients (or jacobians)") 1286 target_shape = target.shape 1287 if target_shape.rank is None: 1288 dim = tensor_shape.Dimension(None) 1289 else: 1290 dim = target_shape.dims[0] 1291 if not (target_shape.with_rank_at_least(2) and 1292 source.shape.with_rank_at_least(2) and 1293 dim.is_compatible_with(source.shape[0])): 1294 raise ValueError( 1295 "Need first dimension of target shape (%s) and " 1296 "source shape (%s) to match." % (target.shape, source.shape)) 1297 if target_shape.is_fully_defined(): 1298 batch_size = int(target_shape[0]) 1299 target_row_size = target_shape.num_elements() // batch_size 1300 else: 1301 target_shape = array_ops.shape(target) 1302 batch_size = target_shape[0] 1303 target_row_size = array_ops.size(target) // batch_size 1304 source_shape = array_ops.shape(source) 1305 # Flatten target to 2-D. 1306 # Note that we push and pop the tape here and below. This is needed since we 1307 # need gradients through the enclosed operations. 1308 with self._ensure_recording(): 1309 with ops.control_dependencies( 1310 [check_ops.assert_equal(batch_size, source_shape[0])]): 1311 target = array_ops.reshape(target, [batch_size, target_row_size]) 1312 1313 run_once = False 1314 1315 def loop_fn(i): 1316 nonlocal run_once 1317 if run_once and not self._persistent: 1318 if parallel_iterations is not None: 1319 raise RuntimeError( 1320 "GradientTape must be created with persistent=True" 1321 " to compute the batch_jacobian with parallel_iterations.") 1322 else: 1323 raise RuntimeError( 1324 "GradientTape must be created with persistent=True" 1325 " to compute the batch_jacobian.") 1326 run_once = True 1327 1328 with self._ensure_recording(): 1329 y = array_ops.gather(target, i, axis=1) 1330 return self.gradient(y, source, 1331 unconnected_gradients=unconnected_gradients) 1332 1333 if experimental_use_pfor: 1334 try: 1335 output = pfor_ops.pfor(loop_fn, target_row_size, 1336 parallel_iterations=parallel_iterations) 1337 except ValueError as err: 1338 six.reraise( 1339 ValueError, 1340 ValueError( 1341 str(err) + "\nEncountered an exception while vectorizing the " 1342 "batch_jacobian computation. Vectorization can be disabled by " 1343 "setting experimental_use_pfor to False."), 1344 sys.exc_info()[2]) 1345 else: 1346 if context.executing_eagerly() and not self._persistent: 1347 raise RuntimeError( 1348 "GradientTape must be created with persistent=True" 1349 " to compute the batch_jacobian with eager execution enabled and " 1350 " with experimental_use_pfor set to False.") 1351 output = pfor_ops.for_loop(loop_fn, target.dtype, target_row_size, 1352 parallel_iterations=parallel_iterations) 1353 new_shape = array_ops.concat([target_shape, source_shape[1:]], axis=0) 1354 if output is None: 1355 # Note that this block is returning zeros when it could use `None` to 1356 # represent unconnected gradients. This is to maintain compatibility with 1357 # the previous behavior, which ignored `unconnected_gradients`. 1358 output = array_ops.zeros(new_shape, target.dtype) 1359 return output 1360 else: 1361 output = array_ops.reshape(output, 1362 [target_row_size, batch_size, -1]) 1363 output = array_ops.transpose(output, [1, 0, 2]) 1364 1365 output = array_ops.reshape(output, new_shape) 1366 return output 1367