# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements the graph generation for computation of gradients."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import contextlib

from six.moves import xrange, zip  # pylint: disable=redefined-builtin

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import backprop
from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function as framework_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework.func_graph import FuncGraph
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_state
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import default_gradient
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import object_identity
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export


def _MarkReachedOps(from_ops, reached_ops, func_graphs):
  """Mark all ops reached from "from_ops".

  Args:
    from_ops: list of Operations.
    reached_ops: set of Operations.
    func_graphs: list of FuncGraphs. This method will traverse through
      these functions if they capture from_ops or any reachable ops.
  """
  queue = collections.deque()
  queue.extend(from_ops)
  while queue:
    op = queue.popleft()
    if op not in reached_ops:
      reached_ops.add(op)
      for output in op.outputs:
        if _IsBackpropagatable(output):
          queue.extend(_Consumers(output, func_graphs))
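

# Illustrative sketch (not part of the original module): _MarkReachedOps is a
# plain breadth-first walk over the forward graph. Assuming graph mode, for
#
#   x = tf.constant(3.0)
#   y = x * x
#   z = y + 1.0
#
# starting from [x.op] it marks the constant, multiply, and add ops, following
# only outputs that _IsBackpropagatable considers trainable (or bfloat16).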


def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
                  xs_set):
  """Initialize the pending count for ops between two lists of Operations.

  'pending_count[op]' indicates the number of backprop inputs
  to this operation.

  Args:
    to_ops: list of Operations.
    from_ops: list of Operations.
    colocate_gradients_with_ops: Python bool. See docstring of gradients().
    func_graphs: list of FuncGraphs. This method will traverse through
      these functions if they capture from_ops or any reachable ops. This is
      useful if to_ops occur in a function and from_ops are in an outer function
      or graph.
    xs_set: ObjectIdentitySet of Tensors.

  Returns:
    A tuple containing: (1) the subset of to_ops reachable from from_ops by a
    path of zero or more backpropagatable tensors, (2) a mapping from operation
    to the number of backprop inputs to that op, and (3) a ControlFlowState
    object which is not None if the ops between from_ops and to_ops contain
    control flow loops.
  """
  # Mark reachable ops from from_ops.
  reached_ops = set()
  _MarkReachedOps(from_ops, reached_ops, func_graphs)
  # X in reached_ops iff X is reachable from from_ops by a path of zero or more
  # backpropagatable tensors.

  reachable_to_ops = set(op for op in to_ops if op in reached_ops)

  # Mark between ops.
  between_ops = set()
  between_op_list = []
  queue = collections.deque()
  queue.extend(to_ops)
  while queue:
    op = queue.popleft()
    # We are interested in this op.
    if op in reached_ops:
      between_ops.add(op)
      between_op_list.append(op)
      # Clear the boolean so we won't add the inputs again.
      reached_ops.remove(op)
      for inp in _NonEagerInputs(op, xs_set):
        queue.append(inp.op)
  # X in between_ops iff X is on a path of zero or more backpropagatable tensors
  # between from_ops and to_ops.

  # 'loop_state' is None if there are no while loops.
  loop_state = control_flow_state.MaybeCreateControlFlowState(
      between_op_list, between_ops, colocate_gradients_with_ops)

  # Initialize pending count for between ops.
  pending_count = collections.defaultdict(int)
  for op in between_op_list:
    for x in _NonEagerInputs(op, xs_set):
      if x.op in between_ops:
        pending_count[x.op] += 1

  return reachable_to_ops, pending_count, loop_state
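

# Illustrative sketch (not part of the original module), assuming graph mode:
# for the graph
#
#   x = tf.constant(1.0)        # "c" below
#   z = x * x                   # "mul"
#   y = z + x                   # "add"
#
# _PendingCount([y.op], [x.op], ...) treats "add", "mul", and "c" as between
# ops. pending_count["mul"] == 1 (one input edge into "add") and
# pending_count["c"] == 3 (one edge into "add", two into "mul"), i.e. the
# number of gradient contributions each op must receive before it can be
# processed during backprop.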


def _AsList(x):
  return x if isinstance(x, (list, tuple)) else [x]


def _DefaultGradYs(grad_ys,
                   ys,
                   colocate_gradients_with_ops,
                   gradient_uid="__unsupported__"):
  """Fill in default values for grad_ys.

  Args:
    grad_ys: List of gradients, can contain None.
    ys: List of tensors.
    colocate_gradients_with_ops: If True, try colocating gradients with
      the corresponding op.
    gradient_uid: A unique identifier within the graph indicating
      which invocation of gradients is being executed. Used to cluster
      ops for compilation.

  Returns:
    A list of gradients to use, without None.

  Raises:
    ValueError: If sizes of gradients and inputs don't match.
    TypeError: If type of any gradient is not valid for its input.
  """
  if len(grad_ys) != len(ys):
    raise ValueError("Passed %d grad_ys for %d ys" % (len(grad_ys), len(ys)))
  grad_ys = ops.convert_n_to_tensor_or_indexed_slices(grad_ys, name="grad_y")
  new_grad_ys = []
  for i, (y, grad_y) in enumerate(zip(ys, grad_ys)):
    with _maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops):
      if grad_y is None:
        if y.dtype.is_complex:
          raise TypeError(
              "Gradients of complex tensors must set grad_ys (y.dtype = %r)" %
              y.dtype)
        new_grad_ys.append(
            array_ops.ones(
                array_ops.shape(y), dtype=y.dtype, name="grad_ys_%d" % i))
        continue
      if y.dtype.is_floating or y.dtype.is_integer:
        if not grad_y.dtype.is_floating and not grad_y.dtype.is_integer:
          raise TypeError(
              "Gradient type %s generated for real or "
              "integer-valued tensor %s with type %s must be "
              "real or integer" % (dtypes.as_dtype(grad_y.dtype).name, y,
                                   dtypes.as_dtype(y.dtype).name))
      elif y.dtype.is_complex:
        if not grad_y.dtype.is_complex:
          raise TypeError(
              "Gradient type %s generated for complex-valued "
              "tensor %s with type %s must be complex" % (dtypes.as_dtype(
                  grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name))
      elif y.dtype == dtypes.variant:
        if grad_y.dtype != dtypes.variant:
          raise TypeError(
              "Gradient type %s generated for variant "
              "tensor %s with type %s must be variant" % (dtypes.as_dtype(
                  grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name))
      elif y.dtype == dtypes.resource:
        # We assume y is the handle of a ResourceVariable. The gradient of a
        # ResourceVariable should be a numeric value, not another resource.
        if grad_y.dtype == dtypes.resource:
          raise TypeError("Input gradient %s for resource tensor %s should not "
                          "be a resource" % (grad_y, y))
      else:
        raise TypeError(
            "Tensor %s with type %s must be numeric "
            "to obtain a default gradient" % (y, dtypes.as_dtype(y.dtype).name))
      # Create a grad_y tensor in the name scope of the gradient.
      # Required for TensorArrays to identify which gradient call a
      # grad_y value is coming from.
      if isinstance(grad_y, ops.IndexedSlices):
        new_grad_ys.append(
            ops.IndexedSlices(
                indices=(array_ops.identity(
                    grad_y.indices, name="grad_ys_%d_indices" % i)
                         if isinstance(grad_y.indices, ops.Tensor) else
                         grad_y.indices),
                values=(array_ops.identity(
                    grad_y.values, name="grad_ys_%d_values" % i) if isinstance(
                        grad_y.values, ops.Tensor) else grad_y.values),
                dense_shape=(array_ops.identity(
                    grad_y.dense_shape, name="grad_ys_%d_shape" % i)
                             if isinstance(grad_y.dense_shape, ops.Tensor) else
                             grad_y.dense_shape)))
      else:
        new_grad_ys.append(array_ops.identity(grad_y, name="grad_ys_%d" % i))

  return new_grad_ys
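

# Illustrative sketch (not part of the original module): when tf.gradients is
# called without explicit grad_ys, _DefaultGradYs seeds backprop with ones of
# the same shape and dtype as each y, so for a float32 y the initial gradient
# is equivalent to
#
#   grad_y = array_ops.ones(array_ops.shape(y), dtype=y.dtype)
#
# Complex ys are the exception: a grad_ys entry must be supplied explicitly,
# otherwise a TypeError is raised above.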


def _IsBackpropagatable(tensor):
  if backprop_util.IsTrainable(tensor):
    return True
  dtype = dtypes.as_dtype(tensor.dtype)
  return dtype.base_dtype == dtypes.bfloat16


def _VerifyGeneratedGradients(grads, op):
  """Verify that gradients are valid in number and type.

  Args:
    grads: List of generated gradients.
    op: Operation for which the gradients were generated.

  Raises:
    ValueError: if sizes of gradients and inputs don't match.
    TypeError: if type of any gradient is not valid for its input.
  """
  # While ops have inputs added to them during the gradient computation, so we
  # skip the below check. See while_v2 for details.
  if op.type == "While" or op.type == "StatelessWhile":
    return

  if len(grads) != len(op.inputs):
    raise ValueError("Num gradients %d generated for op %s do not match num "
                     "inputs %d" % (len(grads), op.node_def, len(op.inputs)))


def _StopOps(from_ops, stop_gradient_ops, pending_count, xs_set):
  """The set of ops that terminate the gradient computation.

  This computes the frontier of the forward graph *before* which backprop
  should stop. Operations in the returned set will not be differentiated.
  This set is defined as the subset of `from_ops` containing ops that have
  no predecessor in `from_ops`. `pending_count` is the result of
  `_PendingCount(xs, from_ops)`. An 'op' has predecessors in `from_ops`
  iff pending_count[op] > 0.

  In addition, none of `stop_gradient_ops` will be differentiated.

  Args:
    from_ops: list of Operations.
    stop_gradient_ops: list of Operations never to backprop through.
    pending_count: mapping from operation to number of backprop inputs.
    xs_set: ObjectIdentitySet of Tensors.

  Returns:
    The set of operations.
  """
  stop_ops = set()
  for op in from_ops:
    is_stop_op = True
    for inp in _NonEagerInputs(op, xs_set):
      if pending_count[inp.op] > 0:
        is_stop_op = False
        break
    if is_stop_op:
      stop_ops.add(op)
  stop_ops.update(op for op in stop_gradient_ops)
  return stop_ops


@contextlib.contextmanager
def _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops):  # pylint: disable=invalid-name
  """Context to colocate with `op` if `colocate_gradients_with_ops`."""
  if colocate_gradients_with_ops:
    with ops._colocate_with_for_gradient(op, gradient_uid):  # pylint: disable=protected-access
      yield
  else:
    yield


def _IsPartitionedCall(op):
  return op.type == "PartitionedCall" or op.type == "StatefulPartitionedCall"


def _SymGrad(op, out_grads):
  """Backprop through a function call node op given its outputs' gradients."""
  f_in = [x for x in op.inputs] + out_grads
  f_types = [default_gradient.get_zeros_dtype(x) for x in op.inputs]
  f = attr_value_pb2.NameAttrList()
  if _IsPartitionedCall(op):
    f.name = op.get_attr("f").name
  else:
    f.name = op.type
  for k in op.node_def.attr:
    f.attr[k].CopyFrom(op.node_def.attr[k])
  in_grads = functional_ops.symbolic_gradient(input=f_in, Tout=f_types, f=f)
  return in_grads
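

# Illustrative sketch (not part of the original module): the stop_gradient_ops
# argument threaded through _StopOps corresponds to the public `stop_gradients`
# parameter of tf.gradients. Assuming graph mode, something like
#
#   a = tf.constant(2.0)
#   b = 2 * a
#   g = tf.gradients(a + b, [a, b], stop_gradients=[b])
#
# treats `b` as a constant during backprop, so no gradient flows through the
# `2 * a` subgraph when differentiating with respect to `a`.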


def _MaybeCompile(scope, op, func, grad_fn):
  """Compile the calculation in grad_fn if op was marked as compiled."""
  scope = scope.rstrip("/").replace("/", "_")
  if func is not None:
    xla_compile = func.definition.attr["_XlaCompile"].b
    xla_separate_compiled_gradients = func.definition.attr[
        "_XlaSeparateCompiledGradients"].b
    xla_scope = func.definition.attr["_XlaScope"].s.decode()
  else:
    try:
      xla_compile = op.get_attr("_XlaCompile")
      xla_separate_compiled_gradients = op.get_attr(
          "_XlaSeparateCompiledGradients")
      xla_scope = op.get_attr("_XlaScope").decode()
    except ValueError:
      xla_compile = False

  if not xla_compile:
    return grad_fn()  # Exit early

  # If the gradients are supposed to be compiled separately, we give them a
  # _XlaScope name that is based on the name_scope of the gradients. Otherwise
  # they just inherit the existing _XlaScope name, which lets them be merged
  # together with the non-gradient computation.
  if xla_separate_compiled_gradients:
    xla_grad_scope = "%s_grad_%s" % (xla_scope, scope)
  else:
    xla_grad_scope = xla_scope

  attrs = {
      "_XlaCompile": attr_value_pb2.AttrValue(b=xla_compile),
      "_XlaScope": attr_value_pb2.AttrValue(s=xla_grad_scope.encode())
  }
  with ops.get_default_graph()._attr_scope(attrs):  # pylint: disable=protected-access
    return grad_fn()


def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs_set):
  """Raises an error if we backprop through a loop var."""
  # Find the nearest 'to_op' reachable from 'op' to provide a more helpful error
  # message.
  target_op = None
  queue = collections.deque([op])
  visited = set()
  while queue:
    curr_op = queue.popleft()
    if curr_op in visited: continue
    visited.add(curr_op)
    if curr_op in from_ops:
      target_op = curr_op
      break
    queue.extend(t.op for t in _NonEagerInputs(curr_op, xs_set))
  assert target_op
  raise ValueError(
      "Cannot compute gradient inside while loop with respect to op '%s'. "
      "We do not support taking the gradient wrt or through the initial value "
      "of a loop variable. Gradients can be computed through loop invariants "
      "or wrt the input parameters to the loop body."
      % target_op.name)


def _IsFunction(graph):
  return (isinstance(graph, FuncGraph) or
          isinstance(graph, framework_function._FuncGraph))  # pylint: disable=protected-access


def _Captures(func_graph):
  if isinstance(func_graph, FuncGraph):
    return func_graph.captures
  else:
    assert isinstance(func_graph, framework_function._FuncGraph)  # pylint: disable=protected-access
    return func_graph.captures


def _MaybeCaptured(t):
  """If t is a captured value placeholder, returns the original captured value.

  Args:
    t: Tensor

  Returns:
    A tensor, potentially from a different Graph/FuncGraph.
  """
  # pylint: disable=protected-access
  if (not isinstance(t, ops.EagerTensor) and
      _IsFunction(t.op.graph) and t.op.type == "Placeholder"):
    for input_t, placeholder_t in _Captures(t.op.graph):
      if t is placeholder_t:
        return _MaybeCaptured(input_t)
  # pylint: enable=protected-access
  return t


def _NonEagerInputs(op, xs_set):
  """Returns the inputs of op, crossing closure boundaries where necessary.

  Does not return any captured EagerTensors, i.e., the number of tensors
  returned may be less than the actual number of inputs.

  Args:
    op: Operation
    xs_set: ObjectIdentitySet of Tensors we are differentiating w.r.t.

  Returns:
    A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
    is in a FuncGraph and has captured inputs.
  """
  return [t for t in _Inputs(op, xs_set) if not isinstance(t, ops.EagerTensor)]
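

# Illustrative sketch (not part of the original module): _MaybeCaptured and
# _NonEagerInputs let backprop cross function-call boundaries. Assuming a
# FuncGraph that captured an outer-graph tensor, e.g.
#
#   x = tf.constant(1.0)
#
#   @tf.function
#   def f():
#     return x * 2.0            # `x` is captured as an internal placeholder
#
# the multiply op's placeholder input is resolved back to the outer `x`, so the
# gradient subgraph can be wired to the tensor the caller actually asked about.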


# TODO(skyewm): plumbing xs through everywhere is ugly, consider making
# _GradientsHelper a class with xs as a member variable.
def _Inputs(op, xs_set):
  """Returns the inputs of op, crossing closure boundaries where necessary.

  Args:
    op: Operation
    xs_set: ObjectIdentitySet of Tensors we are differentiating w.r.t.

  Returns:
    A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
    is in a FuncGraph and has captured inputs.
  """
  if _IsFunction(op.graph):  # pylint: disable=protected-access
    inputs = []
    for t in op.inputs:
      # If we're differentiating w.r.t. `t`, do not attempt to traverse through
      # it to a captured value. The algorithm needs to "see" `t` in this case,
      # even if it's a function input for a captured value, whereas usually we'd
      # like to traverse through these closures as if the captured value was the
      # direct input to op.
      if t not in xs_set:
        t = _MaybeCaptured(t)
      inputs.append(t)
    return inputs
  else:
    return op.inputs


def _Consumers(t, func_graphs):
  """Returns the consumers of t, crossing closure boundaries where necessary.

  Args:
    t: Tensor
    func_graphs: a list of FuncGraphs that may have captured t.

  Returns:
    A list of Operations. The operations will be from the current graph and/or
    func_graphs.
  """
  consumers = t.consumers()
  for func in func_graphs:
    for input_t, placeholder in _Captures(func):
      if input_t is t:
        consumers.extend(_Consumers(placeholder, func_graphs))
  return consumers
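

# Illustrative sketch (not part of the original module): _GradientsHelper below
# is the graph-mode workhorse behind the public `tf.gradients` /
# `gradients_impl.gradients` entry points, which essentially forward their
# arguments here, e.g.
#
#   dy_dx = tf.gradients(ys=y, xs=[x])        # graph mode only
#
# Eager code is expected to use tf.GradientTape instead, as the RuntimeError at
# the top of the function states.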


def _GradientsHelper(ys,
                     xs,
                     grad_ys=None,
                     name="gradients",
                     colocate_gradients_with_ops=False,
                     gate_gradients=False,
                     aggregation_method=None,
                     stop_gradients=None,
                     unconnected_gradients=UnconnectedGradients.NONE,
                     src_graph=None):
  """Implementation of gradients()."""
  if context.executing_eagerly():
    raise RuntimeError("tf.gradients is not supported when eager execution "
                       "is enabled. Use tf.GradientTape instead.")
  if src_graph is None:
    src_graph = ops.get_default_graph()
  try:
    unconnected_gradients = UnconnectedGradients(unconnected_gradients)
  except ValueError:
    raise ValueError(
        "Unknown value for unconnected_gradients: %r" % unconnected_gradients)

  # If src_graph is a _FuncGraph (i.e. a function body), gather it and all
  # ancestor graphs. This is necessary for correctly handling captured values.
  func_graphs = []
  curr_graph = src_graph
  while _IsFunction(curr_graph):
    func_graphs.append(curr_graph)
    if isinstance(curr_graph, FuncGraph):
      curr_graph = curr_graph.outer_graph
    else:
      assert isinstance(curr_graph, framework_function._FuncGraph)  # pylint: disable=protected-access
      curr_graph = curr_graph._outer_graph  # pylint: disable=protected-access

  ys = _AsList(ys)
  xs = _AsList(xs)
  stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)
  if grad_ys is None:
    grad_ys = [None] * len(ys)
  else:
    grad_ys = _AsList(grad_ys)

  with ops.name_scope(
      name, "gradients",
      list(ys) + list(xs) + list(stop_gradients) + list(grad_ys)) as grad_scope:
    # Get a uid for this call to gradients that can be used to help
    # cluster ops for compilation.
    gradient_uid = ops.get_default_graph().unique_name("uid")
    ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
    xs = [
        x.handle if resource_variable_ops.is_resource_variable(x) else x
        for x in xs
    ]
    xs = ops.internal_convert_n_to_tensor_or_indexed_slices(
        xs, name="x", as_ref=True)
    xs_set = object_identity.ObjectIdentitySet(xs)
    grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops,
                             gradient_uid)

    # The approach we take here is as follows: Create a list of all ops in the
    # subgraph between the ys and xs. Visit these ops in reverse order of ids
    # to ensure that when we visit an op the gradients w.r.t its outputs have
    # been collected. Then aggregate these gradients if needed, call the op's
    # gradient function, and add the generated gradients to the gradients for
    # its input.

    # Initialize the pending count for ops in the connected subgraph from ys
    # to the xs.
    to_ops = [t.op for t in ys]
    from_ops = [t.op for t in xs]
    stop_gradient_ops = [t.op for t in stop_gradients]
    reachable_to_ops, pending_count, loop_state = _PendingCount(
        to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs_set)

    # Iterate over the collected ops.
    #
    # grads: op => list of gradients received on each output endpoint of the
    # op. The gradients for each endpoint are initially collected as a list.
    # When it is time to call the op's gradient function, for each endpoint we
    # aggregate the list of received gradients into a Add() Operation if there
    # is more than one.
    grads = {}

    # Add the initial gradients for the ys.
    for y, grad_y in zip(ys, grad_ys):
      _SetGrad(grads, y, grad_y)

    # Initialize queue with to_ops.
    queue = collections.deque()
    # Add the ops in 'to_ops' into the queue.
    to_ops_set = set()
    for op in to_ops:
      # 'ready' handles the case where one output gradient relies on
      # another output's gradient.
      ready = (pending_count[op] == 0)
      if ready and op not in to_ops_set and op in reachable_to_ops:
        to_ops_set.add(op)
        queue.append(op)

    if loop_state:
      loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set)
      for y in loop_exits:
        if backprop_util.IsTrainable(y):
          _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
          queue.append(y.op)

    stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs_set)
    while queue:
      # generate gradient subgraph for op.
      op = queue.popleft()
      with _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops):
        if loop_state:
          loop_state.EnterGradWhileContext(op, before=True)
        out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state,
                                     aggregation_method)
        if loop_state:
          loop_state.ExitGradWhileContext(op, before=True)

        grad_fn = None
        func_call = None
        is_partitioned_call = _IsPartitionedCall(op)
        # pylint: disable=protected-access
        is_func_call = (
            src_graph._is_function(op.type) or is_partitioned_call)
        # pylint: enable=protected-access
        has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
        if has_out_grads and (op not in stop_ops):
          try:
            grad_fn = ops.get_gradient_function(op)
          except LookupError:
            if is_func_call:
              if is_partitioned_call:
                func_name = compat.as_bytes(op.get_attr("f").name)
                func_call = src_graph._get_function(  # pylint: disable=protected-access
                    func_name)
                # When a graph is imported, the FunctionDefs are not copied over
                # to each sub-graph so we recursively search the outer graphs
                # for the FunctionDef.
                if not func_call and hasattr(src_graph, "outer_graph"):
                  graph = src_graph.outer_graph
                  while graph is not None:
                    func_call = graph._get_function(func_name)  # pylint: disable=protected-access
                    if func_call is not None:
                      break
                    if hasattr(graph, "outer_graph"):
                      graph = graph.outer_graph
                    else:
                      break
              else:
                func_call = src_graph._get_function(op.type)  # pylint: disable=protected-access
              # Note that __defun is not set if the graph is
              # imported. If it's set, we prefer to access the original
              # defun.
              func_call = getattr(op, "__defun", func_call)
              grad_fn = func_call.python_grad_func
            else:
              raise LookupError(
                  "No gradient defined for operation " +
                  "'%s' (op type: %s). " % (op.name, op.type) +
                  "In general every operation must have an associated " +
                  "`@tf.RegisterGradient` for correct autodiff, which this " +
                  "op is lacking. If you want to pretend this " +
                  "operation is a constant in your program, you may insert " +
                  "`tf.stop_gradient`. This can be useful to silence the " +
                  "error in cases where you know gradients are not needed, " +
                  "e.g. the forward pass of tf.custom_gradient. " +
                  "Please see more details in " +
                  "https://www.tensorflow.org/api_docs/python/tf/custom_gradient.")  # pylint: disable=line-too-long
        if loop_state:
          loop_state.EnterGradWhileContext(op, before=False)

        # NOTE(skyewm): We don't support computing gradients wrt a loop variable
        # unless it's within the context of a single iteration (i.e. the
        # gradient is wrt to the loop parameter in the body function, not wrt or
        # through the initial value). This means if we're in a while loop
        # context, we should never see a switch node from this context.
        # pylint: disable=protected-access
        if (control_flow_util.IsSwitch(op) and
            op._control_flow_context is not None and
            op._control_flow_context.IsWhileContext() and
            op._control_flow_context ==
            ops.get_default_graph()._get_control_flow_context()):
          _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs_set)
        # pylint: enable=protected-access

        if (grad_fn or is_func_call) and has_out_grads:
          # NOTE: If _AggregatedGrads didn't compute a value for the i'th
          # output, it means that the cost does not depend on output[i],
          # therefore dC/doutput[i] is 0.
          for i, out_grad in enumerate(out_grads):
            if (not isinstance(out_grad, ops.Tensor) and not out_grad) and (
                (not grad_fn and is_func_call)
                or backprop_util.IsTrainable(op.outputs[i])):
              # Only trainable outputs or outputs for a function call that
              # will use SymbolicGradient get a zero gradient. Gradient
              # functions should ignore the gradient for other outputs.
              # TODO(apassos) gradients of resource handles might be an
              # issue here because of zeros.
              if loop_state:
                out_grads[i] = loop_state.ZerosLikeV1WhileLoop(op, i)
              elif default_gradient.supports_default_grad(op.outputs[i]):
                # TODO(b/143286622): The supports_default_grad check is needed
                # because While op emits non-differentiable resource tensors
                # as outputs. Remove this check when that is not the case.
                out_grads[i] = control_flow_state.ZerosLike(op, i)
          with ops.name_scope(op.name + "_grad"):
            # pylint: disable=protected-access
            with src_graph._original_op(op):
              # pylint: enable=protected-access
              if grad_fn:
                # If grad_fn was found, do not use SymbolicGradient even for
                # functions.
                in_grads = _MaybeCompile(grad_scope, op, func_call,
                                         lambda: grad_fn(op, *out_grads))
              else:
                # For function call ops, we add a 'SymbolicGradient'
                # node to the graph to compute gradients.
                in_grads = _MaybeCompile(grad_scope, op, func_call,
                                         lambda: _SymGrad(op, out_grads))
              in_grads = _AsList(in_grads)
              _VerifyGeneratedGradients(in_grads, op)
              if gate_gradients and len([x for x in in_grads
                                         if x is not None]) > 1:
                with ops.device(None):
                  with ops._colocate_with_for_gradient(  # pylint: disable=protected-access
                      None,
                      gradient_uid,
                      ignore_existing=True):
                    in_grads = control_flow_ops.tuple(in_grads)
          _LogOpGradients(op, out_grads, in_grads)
        else:
          # If no grad_fn is defined or none of out_grads is available,
          # just propagate a list of None backwards.
          in_grads = [None] * len(_Inputs(op, xs_set))
        # Note: we don't filter out eager inputs here because the inputs need to
        # line up with in_grads.
        for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs_set), in_grads)):
          if in_grad is not None:
            if (isinstance(in_grad, ops.Tensor) and
                t_in.dtype != dtypes.resource):
              try:
                in_grad.set_shape(t_in.get_shape())
              except ValueError:
                raise ValueError(
                    "Incompatible shapes between op input and calculated "
                    "input gradient. Forward operation: %s. Input index: %d. "
                    "Original input shape: %s. "
                    "Calculated input gradient shape: %s" %
                    (op.name, i, t_in.shape, in_grad.shape))
            if not isinstance(t_in, ops.EagerTensor):
              _SetGrad(grads, t_in, in_grad)
        if loop_state:
          loop_state.ExitGradWhileContext(op, before=False)

      # Update pending count for the inputs of op and enqueue ready ops.
      _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
                                    xs_set)

  if loop_state:
    loop_state.PostProcessing()
  return [_GetGrad(grads, x, unconnected_gradients) for x in xs]


def _HasAnyNotNoneGrads(grads, op):
  """Return true iff op has real gradient."""
  out_grads = _GetGrads(grads, op)
  for out_grad in out_grads:
    if isinstance(out_grad, (ops.Tensor, ops.IndexedSlices)):
      return True
    if out_grad and isinstance(out_grad, collections_abc.Sequence):
      if any(g is not None for g in out_grad):
        return True
  return False
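

# Illustrative note (not part of the original module): the `grads` dictionary
# used above and in the helpers below maps each Operation to a list with one
# entry per op output. Each entry starts as a (possibly empty) list of gradient
# Tensors contributed by consumers, e.g.
#
#   grads[op] == [[g0_a, g0_b], []]   # two grads received for output 0, none
#                                     # yet for output 1
#
# and is collapsed into a single Tensor by _AggregatedGrads before the op's
# gradient function runs.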


def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
                                  xs_set):
  """Update pending count for the inputs of op and enqueue ready ops."""
  for x in _NonEagerInputs(op, xs_set):
    pending_count[x.op] -= 1
    ready = (pending_count[x.op] == 0)
    if loop_state and not ready:
      ready = pending_count[x.op] > 0 and control_flow_util.IsLoopSwitch(x.op)
    if ready:
      if control_flow_util.IsLoopExit(x.op):
        # if x is an exit without real gradient, defer processing them.
        grad_state = loop_state.GetGradState(x.op, before=False)
        grad_state.deferred_exits.append(x)
        grad_state.pending_exits_count -= 1
        if grad_state.pending_exits_count == 0:
          # We now have all the exits so process them.
          has_not_none_grad = False
          for y in grad_state.deferred_exits:
            if _HasAnyNotNoneGrads(grads, y.op):
              has_not_none_grad = True
              queue.append(y.op)
            else:
              grad_state.unused_exits.append(y)
          if has_not_none_grad:
            # For an unused exit, if it has trainable outputs, backprop
            # a zero gradient. Otherwise, just ignore it.
            for y in grad_state.unused_exits:
              if backprop_util.IsTrainable(y):
                _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
              queue.append(y.op)
          else:
            # All exits are "unused" so use None as gradient.
            for y in grad_state.unused_exits:
              queue.append(y.op)
      else:
        queue.append(x.op)


def _SetGrad(grads, t, grad):
  """Sets gradient "grad" in "grads" for tensor "t"."""
  op = t.op
  op_grads = grads.get(op)
  if not op_grads:
    op_grads = [[] for _ in xrange(len(op.outputs))]
    grads[op] = op_grads
  t_grads = op_grads[t.value_index]
  if isinstance(t_grads, list):
    t_grads.append(grad)
  else:
    assert control_flow_util.IsLoopSwitch(op)
    op_grads[t.value_index] = grad


def _ZerosLike(t):
  t_dtype = default_gradient.get_zeros_dtype(t)
  if t.dtype == dtypes.resource:
    return array_ops.zeros(
        resource_variable_ops.variable_shape(t), dtype=t_dtype)
  else:
    return array_ops.zeros_like(t, dtype=t_dtype)


def _GetGrad(grads, t, unconnected_gradients):
  """Gets gradient for tensor "t"."""
  op = t.op
  op_grads = grads.get(op)
  if not op_grads:
    if unconnected_gradients == UnconnectedGradients.ZERO:
      return _ZerosLike(t)
    elif unconnected_gradients == UnconnectedGradients.NONE:
      return None
    else:
      raise ValueError(
          "Unknown value for unconnected_gradients: %r" % unconnected_gradients)

  t_grad = op_grads[t.value_index]
  # This can happen if some other output of `t.op` has non-None grad.
  if unconnected_gradients == UnconnectedGradients.ZERO and t_grad is None:
    return _ZerosLike(t)

  assert not isinstance(
      t_grad, list), ("gradients list should have been aggregated by now.")
  return t_grad


def _GetGrads(grads, op):
  """Gets all gradients for op."""
  if op in grads:
    return grads[op]
  else:
    return [[] for _ in xrange(len(op.outputs))]


def _AccumulatorShape(inputs):
  shape = tensor_shape.unknown_shape()
  for i in inputs:
    if isinstance(i, ops.Tensor):
      shape = shape.merge_with(i.get_shape())
  return shape


def _LogOpGradients(op, out_grads, in_grads):
  """Log the in and out grads of an op."""
  logging.vlog(1, "Gradient for '" + op.name + "'")

  def _FilterGrad(x):
    if x is None:
      return False
    if isinstance(x, (list, tuple)):
      return bool(x)
    else:
      return True

  logging.vlog(1, "  in  --> %s",
               ", ".join(x.name for x in out_grads if _FilterGrad(x)))
  logging.vlog(1, "  out --> %s",
               ", ".join(x.name for x in in_grads if _FilterGrad(x)))
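

# Illustrative sketch (not part of the original module): _GetGrad above is
# where the `unconnected_gradients` option takes effect. Assuming graph mode,
#
#   x = tf.compat.v1.placeholder(tf.float32)
#   y = tf.constant(1.0)                              # does not depend on x
#   tf.gradients(y, x)                                # -> [None]
#   tf.gradients(y, x, unconnected_gradients="zero")  # -> [zeros shaped like x]
#
# i.e. NONE returns None for unconnected xs, while ZERO returns a zero tensor
# built by _ZerosLike.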


def _MultiDeviceAddN(tensor_list, gradient_uid):
  """Adds tensors from potentially multiple devices."""
  # Basic function structure comes from control_flow_ops.group().
  # Sort tensors according to their devices.
  tensors_on_device = collections.defaultdict(lambda: [])
  for tensor in tensor_list:
    tensors_on_device[tensor.device].append(tensor)

  # For each device, add the tensors on that device first.
  # Then gather the partial sums from multiple devices.
  # TODO(sjhwang): Create hierarchical aggregation tree as pbar's suggestion.
  # E.g., aggregate per GPU, then per task, and so on.
  summands = []

  def DeviceKey(dev):
    return "" if dev is None else dev

  for dev in sorted(tensors_on_device, key=DeviceKey):
    tensors = tensors_on_device[dev]
    with ops._colocate_with_for_gradient(  # pylint: disable=protected-access
        tensors[0].op,
        gradient_uid,
        ignore_existing=True):
      summands.append(math_ops.add_n(tensors))

  return math_ops.add_n(summands)


@tf_export("AggregationMethod")
class AggregationMethod(object):
  """A class listing aggregation methods used to combine gradients.

  Computing partial derivatives can require aggregating gradient
  contributions. This class lists the various methods that can
  be used to combine gradients in the graph.

  The following aggregation methods are part of the stable API for
  aggregating gradients:

  * `ADD_N`: All of the gradient terms are summed as part of one
    operation using the "AddN" op (see `tf.add_n`). This
    method has the property that all gradients must be ready and
    buffered separately in memory before any aggregation is performed.
  * `DEFAULT`: The system-chosen default aggregation method.

  The following aggregation methods are experimental and may not
  be supported in future releases:

  * `EXPERIMENTAL_TREE`: Gradient terms are summed in pairs using
    the "AddN" op. This method of summing gradients may reduce
    performance, but it can improve memory utilization because the
    gradients can be released earlier.

  """
  ADD_N = 0
  DEFAULT = ADD_N
  # The following are experimental and may not be supported in future releases.
  EXPERIMENTAL_TREE = 1
  EXPERIMENTAL_ACCUMULATE_N = 2  # An alias for EXPERIMENTAL_ADD_N = 1
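

# Illustrative sketch (not part of the original module): the constants above
# are consumed by _AggregatedGrads below via the `aggregation_method` argument
# of tf.gradients, e.g.
#
#   g = tf.gradients(y, xs,
#                    aggregation_method=tf.AggregationMethod.EXPERIMENTAL_TREE)
#
# ADD_N sums all contributions in a single AddN op, while EXPERIMENTAL_TREE and
# EXPERIMENTAL_ACCUMULATE_N (handled identically here) sum them pairwise to
# reduce peak memory.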


def _AggregatedGrads(grads,
                     op,
                     gradient_uid,
                     loop_state,
                     aggregation_method=None):
  """Get the aggregated gradients for op.

  Args:
    grads: The map of memoized gradients.
    op: The op to get gradients for.
    gradient_uid: A unique identifier within the graph indicating
      which invocation of gradients is being executed. Used to cluster
      ops for compilation.
    loop_state: An object for maintaining the state of the while loops in the
      graph. It is of type ControlFlowState. None if the graph
      contains no while loops.
    aggregation_method: Specifies the method used to combine gradient terms.
      Accepted values are constants defined in the class `AggregationMethod`.

  Returns:
    A list of gradients, one per each output of `op`. If the gradients
    for a particular output is a list, this function aggregates it
    before returning.

  Raises:
    TypeError: if the incoming grads are not Tensors or IndexedSlices.
    ValueError: if the arguments are invalid.

  """
  if aggregation_method is None:
    aggregation_method = AggregationMethod.DEFAULT
  if aggregation_method not in [
      AggregationMethod.ADD_N, AggregationMethod.EXPERIMENTAL_TREE,
      AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
  ]:
    raise ValueError(
        "Invalid aggregation_method specified %s." % aggregation_method)
  out_grads = _GetGrads(grads, op)
  for i, out_grad in enumerate(out_grads):
    if loop_state:
      if isinstance(out_grad, (ops.Tensor, ops.IndexedSlices)):
        assert control_flow_util.IsLoopSwitch(op)
        continue
    # Grads have to be Tensors or IndexedSlices
    if (isinstance(out_grad, collections_abc.Sequence) and not all(
        isinstance(g, (ops.Tensor, ops.IndexedSlices))
        for g in out_grad
        if g is not None)):
      raise TypeError("gradients have to be either all Tensors "
                      "or all IndexedSlices")
    # Aggregate multiple gradients, and convert [] to None.
    if out_grad:
      if len(out_grad) < 2:
        used = "nop"
        out_grads[i] = out_grad[0]
      elif all(isinstance(g, ops.Tensor) for g in out_grad if g is not None):
        tensor_shape = _AccumulatorShape(out_grad)
        if aggregation_method in [
            AggregationMethod.EXPERIMENTAL_TREE,
            AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
        ]:
          # Aggregate all gradients by doing pairwise sums: this may
          # reduce performance, but it can improve memory because the
          # gradients can be released earlier.
          #
          # TODO(vrv): Consider replacing this with a version of
          # tf.AddN() that eagerly frees its inputs as soon as they are
          # ready, so the order of this tree does not become a problem.
          used = "tree"
          with ops.name_scope(op.name + "_gradient_sum"):
            running_sum = out_grad[0]
            for grad in out_grad[1:]:
              running_sum = math_ops.add_n([running_sum, grad])
            out_grads[i] = running_sum
        else:
          used = "add_n"
          out_grads[i] = _MultiDeviceAddN(out_grad, gradient_uid)
        logging.vlog(2, "  _AggregatedGrads %d x %s using %s", len(out_grad),
                     tensor_shape, used)
      else:
        out_grads[i] = backprop.aggregate_indexed_slices_gradients(out_grad)  # pylint: disable=protected-access
    else:  # not out_grad
      # out_grads[i] is [], thus its aggregation is simply None.
      out_grads[i] = None
  return out_grads


# Represents the output of TFE_Py_TapeSetPossibleGradientTypes. Real enums are
# unfortunately too slow to use here.
POSSIBLE_GRADIENT_TYPES_NONE = 0
POSSIBLE_GRADIENT_TYPES_FIRST_ORDER = 1
POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER = 2


def PossibleTapeGradientTypes(tensors):
  """Determines whether and how `tensors` may require tape gradients."""
  return pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes(tensors)
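

# Illustrative note (not part of the original module): callers of
# PossibleTapeGradientTypes compare its result against the constants above,
# e.g.
#
#   if PossibleTapeGradientTypes(tensors) == POSSIBLE_GRADIENT_TYPES_NONE:
#     ...  # no active tape could request gradients for these tensors
#
# FIRST_ORDER and HIGHER_ORDER similarly indicate whether only first-order or
# possibly higher-order tape gradients may be needed.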