1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Implements the graph generation for computation of gradients.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import contextlib 23 24from six.moves import xrange, zip # pylint: disable=redefined-builtin 25 26from tensorflow.core.framework import attr_value_pb2 27from tensorflow.python import pywrap_tfe 28from tensorflow.python.eager import backprop 29from tensorflow.python.eager import backprop_util 30from tensorflow.python.eager import context 31from tensorflow.python.framework import constant_op 32from tensorflow.python.framework import dtypes 33from tensorflow.python.framework import function as framework_function 34from tensorflow.python.framework import ops 35from tensorflow.python.framework import tensor_shape 36from tensorflow.python.framework.func_graph import FuncGraph 37from tensorflow.python.ops import array_ops 38from tensorflow.python.ops import control_flow_ops 39from tensorflow.python.ops import control_flow_state 40from tensorflow.python.ops import control_flow_util 41from tensorflow.python.ops import default_gradient 42from tensorflow.python.ops import functional_ops 43from tensorflow.python.ops import math_ops 44from tensorflow.python.ops import resource_variable_ops 45from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients 46from tensorflow.python.platform import tf_logging as logging 47from tensorflow.python.util import compat 48from tensorflow.python.util import object_identity 49from tensorflow.python.util.compat import collections_abc 50from tensorflow.python.util.tf_export import tf_export 51 52 53def _MarkReachedOps(from_ops, reached_ops, func_graphs): 54 """Mark all ops reached from "from_ops". 55 56 Args: 57 from_ops: list of Operations. 58 reached_ops: set of Operations. 59 func_graphs: list of FuncGraphs. This method will traverse through 60 these functions if they capture from_ops or any reachable ops. 61 """ 62 queue = collections.deque() 63 queue.extend(from_ops) 64 while queue: 65 op = queue.popleft() 66 if op not in reached_ops: 67 reached_ops.add(op) 68 for output in op.outputs: 69 if _IsBackpropagatable(output): 70 queue.extend(_Consumers(output, func_graphs)) 71 72 73def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs, 74 xs_set): 75 """Initialize the pending count for ops between two lists of Operations. 76 77 'pending_count[op]' indicates the number of backprop inputs 78 to this operation. 79 80 Args: 81 to_ops: list of Operations. 82 from_ops: list of Operations. 83 colocate_gradients_with_ops: Python bool. See docstring of gradients(). 84 func_graphs: list of FuncGraphs. This method will traverse through 85 these functions if they capture from_ops or any reachable ops. This is 86 useful if to_ops occur in a function and from_ops are in an outer function 87 or graph. 88 xs_set: ObjectIdentitySet of Tensors. 89 90 Returns: 91 A tuple containing: (1) the subset of to_ops reachable from from_ops by a 92 path of zero or more backpropagatable tensors, (2) a mapping from operation 93 to the number of backprop inputs to that op, and (3) a ControlFlowState 94 object which is not None if the ops between from_ops and to_ops contain 95 control flow loops. 96 """ 97 # Mark reachable ops from from_ops. 98 reached_ops = set() 99 _MarkReachedOps(from_ops, reached_ops, func_graphs) 100 # X in reached_ops iff X is reachable from from_ops by a path of zero or more 101 # backpropagatable tensors. 102 103 reachable_to_ops = set(op for op in to_ops if op in reached_ops) 104 105 # Mark between ops. 106 between_ops = set() 107 between_op_list = [] 108 queue = collections.deque() 109 queue.extend(to_ops) 110 while queue: 111 op = queue.popleft() 112 # We are interested in this op. 113 if op in reached_ops: 114 between_ops.add(op) 115 between_op_list.append(op) 116 # Clear the boolean so we won't add the inputs again. 117 reached_ops.remove(op) 118 for inp in _NonEagerInputs(op, xs_set): 119 queue.append(inp.op) 120 # X in between_ops iff X is on a path of zero or more backpropagatable tensors 121 # between from_ops and to_ops 122 123 # 'loop_state' is None if there are no while loops. 124 loop_state = control_flow_state.MaybeCreateControlFlowState( 125 between_op_list, between_ops, colocate_gradients_with_ops) 126 127 # Initialize pending count for between ops. 128 pending_count = collections.defaultdict(int) 129 for op in between_op_list: 130 for x in _NonEagerInputs(op, xs_set): 131 if x.op in between_ops: 132 pending_count[x.op] += 1 133 134 return reachable_to_ops, pending_count, loop_state 135 136 137def _AsList(x): 138 return x if isinstance(x, (list, tuple)) else [x] 139 140 141def _DefaultGradYs(grad_ys, 142 ys, 143 colocate_gradients_with_ops, 144 gradient_uid="__unsupported__"): 145 """Fill in default values for grad_ys. 146 147 Args: 148 grad_ys: List of gradients, can contain None. 149 ys: List of tensors. 150 colocate_gradients_with_ops: If True, try colocating gradients with 151 the corresponding op. 152 gradient_uid: A unique identifier within the graph indicating 153 which invocation of gradients is being executed. Used to cluster 154 ops for compilation. 155 156 Returns: 157 A list of gradients to use, without None. 158 159 Raises: 160 ValueError: If sizes of gradients and inputs don't match 161 TypeError: If type of any gradient is not valid for its input. 162 """ 163 if len(grad_ys) != len(ys): 164 raise ValueError("Passed %d grad_ys for %d ys" % (len(grad_ys), len(ys))) 165 grad_ys = ops.convert_n_to_tensor_or_indexed_slices(grad_ys, name="grad_y") 166 new_grad_ys = [] 167 for i, (y, grad_y) in enumerate(zip(ys, grad_ys)): 168 with _maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops): 169 if grad_y is None: 170 if y.dtype.is_complex: 171 raise TypeError( 172 "Gradients of complex tensors must set grad_ys (y.dtype = %r)" % 173 y.dtype) 174 new_grad_ys.append( 175 array_ops.fill( 176 array_ops.shape(y), 177 constant_op.constant(1, dtype=y.dtype, name="grad_ys_%d" % i))) 178 continue 179 if y.dtype.is_floating or y.dtype.is_integer: 180 if not grad_y.dtype.is_floating and not grad_y.dtype.is_integer: 181 raise TypeError( 182 "Gradient type %s generated for real or " 183 "integer-valued tensor %s with type %s must be " 184 "real or integer" % (dtypes.as_dtype(grad_y.dtype).name, y, 185 dtypes.as_dtype(y.dtype).name)) 186 elif y.dtype.is_complex: 187 if not grad_y.dtype.is_complex: 188 raise TypeError( 189 "Gradient type %s generated for complex-valued " 190 "tensor %s with type %s must be real" % (dtypes.as_dtype( 191 grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name)) 192 elif y.dtype == dtypes.variant: 193 if grad_y.dtype != dtypes.variant: 194 raise TypeError( 195 "Gradient type %s generated for variant " 196 "tensor %s with type %s must be variant" % (dtypes.as_dtype( 197 grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name)) 198 elif y.dtype == dtypes.resource: 199 # We assume y is the handle of a ResourceVariable. The gradient of a 200 # ResourceVariable should be a numeric value, not another resource. 201 if grad_y.dtype == dtypes.resource: 202 raise TypeError("Input gradient %s for resource tensor %s should not " 203 "be a resource" % (grad_y, y)) 204 else: 205 raise TypeError( 206 "Tensor %s with type %s must be numeric " 207 "to obtain a default gradient" % (y, dtypes.as_dtype(y.dtype).name)) 208 # Create a grad_y tensor in the name scope of the gradient. 209 # Required for TensorArrays to identify which gradient call a 210 # grad_y value is coming from. 211 if isinstance(grad_y, ops.IndexedSlices): 212 new_grad_ys.append( 213 ops.IndexedSlices( 214 indices=(array_ops.identity( 215 grad_y.indices, name="grad_ys_%d_indices" % i) 216 if isinstance(grad_y.indices, ops.Tensor) else 217 grad_y.indices), 218 values=(array_ops.identity( 219 grad_y.values, name="grad_ys_%d_values" % i) if isinstance( 220 grad_y.values, ops.Tensor) else grad_y.values), 221 dense_shape=(array_ops.identity( 222 grad_y.dense_shape, name="grad_ys_%d_shape" % i) 223 if isinstance(grad_y.dense_shape, ops.Tensor) else 224 grad_y.dense_shape))) 225 else: 226 new_grad_ys.append(array_ops.identity(grad_y, name="grad_ys_%d" % i)) 227 228 return new_grad_ys 229 230 231def _IsBackpropagatable(tensor): 232 if backprop_util.IsTrainable(tensor): 233 return True 234 dtype = dtypes.as_dtype(tensor.dtype) 235 return dtype.base_dtype == dtypes.bfloat16 236 237 238def _VerifyGeneratedGradients(grads, op): 239 """Verify that gradients are valid in number and type. 240 241 Args: 242 grads: List of generated gradients. 243 op: Operation for which the gradients where generated. 244 245 Raises: 246 ValueError: if sizes of gradients and inputs don't match. 247 TypeError: if type of any gradient is not valid for its input. 248 """ 249 # While ops have inputs added to them during the gradient computation, so we 250 # skip the below check. See while_v2 for details. 251 if op.type == "While" or op.type == "StatelessWhile": 252 return 253 254 if len(grads) != len(op.inputs): 255 raise ValueError("Num gradients %d generated for op %s do not match num " 256 "inputs %d" % (len(grads), op.node_def, len(op.inputs))) 257 258 259def _StopOps(from_ops, stop_gradient_ops, pending_count, xs_set): 260 """The set of ops that terminate the gradient computation. 261 262 This computes the frontier of the forward graph *before* which backprop 263 should stop. Operations in the returned set will not be differentiated. 264 This set is defined as the subset of `from_ops` containing ops that have 265 no predecessor in `from_ops`. `pending_count` is the result of 266 `_PendingCount(xs, from_ops)`. An 'op' has predecessors in `from_ops` 267 iff pending_count[op] > 0. 268 269 In addition, none of `stop_gradient_ops` will be differentiated. 270 271 Args: 272 from_ops: list of Operations. 273 stop_gradient_ops: list of Operations never to backprop through. 274 pending_count: mapping from operation to number of backprop inputs. 275 xs_set: ObjectIdentitySet of Tensors. 276 277 Returns: 278 The set of operations. 279 """ 280 stop_ops = set() 281 for op in from_ops: 282 is_stop_op = True 283 for inp in _NonEagerInputs(op, xs_set): 284 if pending_count[inp.op] > 0: 285 is_stop_op = False 286 break 287 if is_stop_op: 288 stop_ops.add(op) 289 stop_ops.update(op for op in stop_gradient_ops) 290 return stop_ops 291 292 293@contextlib.contextmanager 294def _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops): # pylint: disable=invalid-name 295 """Context to colocate with `op` if `colocate_gradients_with_ops`.""" 296 if colocate_gradients_with_ops: 297 with ops._colocate_with_for_gradient(op, gradient_uid): # pylint: disable=protected-access 298 yield 299 else: 300 yield 301 302 303def _IsPartitionedCall(op): 304 return op.type == "PartitionedCall" or op.type == "StatefulPartitionedCall" 305 306 307def _SymGrad(op, out_grads): 308 """Backprop through a function call node op given its outputs' gradients.""" 309 f_in = [x for x in op.inputs] + out_grads 310 f_types = [default_gradient.get_zeros_dtype(x) for x in op.inputs] 311 f = attr_value_pb2.NameAttrList() 312 if _IsPartitionedCall(op): 313 f.name = op.get_attr("f").name 314 else: 315 f.name = op.type 316 for k in op.node_def.attr: 317 f.attr[k].CopyFrom(op.node_def.attr[k]) 318 in_grads = functional_ops.symbolic_gradient(input=f_in, Tout=f_types, f=f) 319 return in_grads 320 321 322def _MaybeCompile(scope, op, func, grad_fn): 323 """Compile the calculation in grad_fn if op was marked as compiled.""" 324 scope = scope.rstrip("/").replace("/", "_") 325 if func is not None: 326 xla_compile = func.definition.attr["_XlaCompile"].b 327 xla_separate_compiled_gradients = func.definition.attr[ 328 "_XlaSeparateCompiledGradients"].b 329 xla_scope = func.definition.attr["_XlaScope"].s.decode() 330 else: 331 try: 332 xla_compile = op.get_attr("_XlaCompile") 333 xla_separate_compiled_gradients = op.get_attr( 334 "_XlaSeparateCompiledGradients") 335 xla_scope = op.get_attr("_XlaScope").decode() 336 except ValueError: 337 xla_compile = False 338 339 if not xla_compile: 340 return grad_fn() # Exit early 341 342 # If the gradients are supposed to be compiled separately, we give them a 343 # _XlaScope name that is based on the name_scope of the gradients. Otherwise 344 # they just inherit the existing _XlaScope name, which lets them be merged 345 # together with the non-gradient computation. 346 if xla_separate_compiled_gradients: 347 xla_grad_scope = "%s_grad_%s" % (xla_scope, scope) 348 else: 349 xla_grad_scope = xla_scope 350 351 attrs = { 352 "_XlaCompile": attr_value_pb2.AttrValue(b=xla_compile), 353 "_XlaScope": attr_value_pb2.AttrValue(s=xla_grad_scope.encode()) 354 } 355 with ops.get_default_graph()._attr_scope(attrs): # pylint: disable=protected-access 356 return grad_fn() 357 358 359def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs_set): 360 """Raises an error if we backprop through a loop var.""" 361 # Find the nearest 'to_op' reachable from 'op' to provide a more helpful error 362 # message. 363 target_op = None 364 queue = collections.deque([op]) 365 visited = set() 366 while queue: 367 curr_op = queue.popleft() 368 if curr_op in visited: continue 369 visited.add(curr_op) 370 if curr_op in from_ops: 371 target_op = curr_op 372 break 373 queue.extend(t.op for t in _NonEagerInputs(curr_op, xs_set)) 374 assert target_op 375 raise ValueError( 376 "Cannot compute gradient inside while loop with respect to op '%s'. " 377 "We do not support taking the gradient wrt or through the initial value " 378 "of a loop variable. Gradients can be computed through loop invariants " 379 "or wrt the input parameters to the loop body." 380 % target_op.name) 381 382 383def _IsFunction(graph): 384 return (isinstance(graph, FuncGraph) or 385 isinstance(graph, framework_function._FuncGraph)) # pylint: disable=protected-access 386 387 388def _Captures(func_graph): 389 if isinstance(func_graph, FuncGraph): 390 return func_graph.captures 391 else: 392 assert isinstance(func_graph, framework_function._FuncGraph) # pylint: disable=protected-access 393 return func_graph.captures 394 395 396def _MaybeCaptured(t): 397 """If t is a captured value placeholder, returns the original captured value. 398 399 Args: 400 t: Tensor 401 402 Returns: 403 A tensor, potentially from a different Graph/FuncGraph. 404 """ 405 # pylint: disable=protected-access 406 if (not isinstance(t, ops.EagerTensor) and 407 _IsFunction(t.op.graph) and t.op.type == "Placeholder"): 408 for input_t, placeholder_t in _Captures(t.op.graph): 409 if t is placeholder_t: 410 return _MaybeCaptured(input_t) 411 # pylint: enable=protected-access 412 return t 413 414 415def _NonEagerInputs(op, xs_set): 416 """Returns the inputs of op, crossing closure boundaries where necessary. 417 418 Does not return any captured EagerTensors, i.e., the number of tensors 419 returned may be less than the actual number of inputs. 420 421 Args: 422 op: Operation 423 xs_set: ObjectIdentitySet of Tensors we are differentiating w.r.t. 424 425 Returns: 426 A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op 427 is in a FuncGraph and has captured inputs. 428 """ 429 return [t for t in _Inputs(op, xs_set) if not isinstance(t, ops.EagerTensor)] 430 431 432# TODO(skyewm): plumbing xs through everywhere is ugly, consider making 433# _GradientsHelper a class with xs as a member variable. 434def _Inputs(op, xs_set): 435 """Returns the inputs of op, crossing closure boundaries where necessary. 436 437 Args: 438 op: Operation 439 xs_set: ObjectIdentitySet of Tensors we are differentiating w.r.t. 440 441 Returns: 442 A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op 443 is in a FuncGraph and has captured inputs. 444 """ 445 if _IsFunction(op.graph): # pylint: disable=protected-access 446 inputs = [] 447 for t in op.inputs: 448 # If we're differentiating w.r.t. `t`, do not attempt to traverse through 449 # it to a captured value. The algorithm needs to "see" `t` in this case, 450 # even if it's a function input for a captured value, whereas usually we'd 451 # like to traverse through these closures as if the captured value was the 452 # direct input to op. 453 if t not in xs_set: 454 t = _MaybeCaptured(t) 455 inputs.append(t) 456 return inputs 457 else: 458 return op.inputs 459 460 461def _Consumers(t, func_graphs): 462 """Returns the consumers of t, crossing closure boundaries where necessary. 463 464 Args: 465 t: Tensor 466 func_graphs: a list of FuncGraphs that may have captured t. 467 468 Returns: 469 A list of tensors. The tensors will be from the current graph and/or 470 func_graphs. 471 """ 472 consumers = t.consumers() 473 for func in func_graphs: 474 for input_t, placeholder in _Captures(func): 475 if input_t is t: 476 consumers.extend(_Consumers(placeholder, func_graphs)) 477 return consumers 478 479 480def _GradientsHelper(ys, 481 xs, 482 grad_ys=None, 483 name="gradients", 484 colocate_gradients_with_ops=False, 485 gate_gradients=False, 486 aggregation_method=None, 487 stop_gradients=None, 488 unconnected_gradients=UnconnectedGradients.NONE, 489 src_graph=None): 490 """Implementation of gradients().""" 491 if context.executing_eagerly(): 492 raise RuntimeError("tf.gradients is not supported when eager execution " 493 "is enabled. Use tf.GradientTape instead.") 494 if src_graph is None: 495 src_graph = ops.get_default_graph() 496 try: 497 unconnected_gradients = UnconnectedGradients(unconnected_gradients) 498 except ValueError: 499 raise ValueError( 500 "Unknown value for unconnected_gradients: %r" % unconnected_gradients) 501 502 # If src_graph is a _FuncGraph (i.e. a function body), gather it and all 503 # ancestor graphs. This is necessary for correctly handling captured values. 504 func_graphs = [] 505 curr_graph = src_graph 506 while _IsFunction(curr_graph): 507 func_graphs.append(curr_graph) 508 if isinstance(curr_graph, FuncGraph): 509 curr_graph = curr_graph.outer_graph 510 else: 511 assert isinstance(curr_graph, framework_function._FuncGraph) # pylint: disable=protected-access 512 curr_graph = curr_graph._outer_graph # pylint: disable=protected-access 513 514 ys = _AsList(ys) 515 xs = _AsList(xs) 516 stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients) 517 if grad_ys is None: 518 grad_ys = [None] * len(ys) 519 else: 520 grad_ys = _AsList(grad_ys) 521 522 with ops.name_scope( 523 name, "gradients", 524 list(ys) + list(xs) + list(stop_gradients) + list(grad_ys)) as grad_scope: 525 # Get a uid for this call to gradients that can be used to help 526 # cluster ops for compilation. 527 gradient_uid = ops.get_default_graph().unique_name("uid") 528 ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y") 529 xs = [ 530 x.handle if resource_variable_ops.is_resource_variable(x) else x 531 for x in xs 532 ] 533 xs = ops.internal_convert_n_to_tensor_or_indexed_slices( 534 xs, name="x", as_ref=True) 535 xs_set = object_identity.ObjectIdentitySet(xs) 536 grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops, 537 gradient_uid) 538 539 # The approach we take here is as follows: Create a list of all ops in the 540 # subgraph between the ys and xs. Visit these ops in reverse order of ids 541 # to ensure that when we visit an op the gradients w.r.t its outputs have 542 # been collected. Then aggregate these gradients if needed, call the op's 543 # gradient function, and add the generated gradients to the gradients for 544 # its input. 545 546 # Initialize the pending count for ops in the connected subgraph from ys 547 # to the xs. 548 to_ops = [t.op for t in ys] 549 from_ops = [t.op for t in xs] 550 stop_gradient_ops = [t.op for t in stop_gradients] 551 reachable_to_ops, pending_count, loop_state = _PendingCount( 552 to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs_set) 553 554 # Iterate over the collected ops. 555 # 556 # grads: op => list of gradients received on each output endpoint of the 557 # op. The gradients for each endpoint are initially collected as a list. 558 # When it is time to call the op's gradient function, for each endpoint we 559 # aggregate the list of received gradients into a Add() Operation if there 560 # is more than one. 561 grads = {} 562 563 # Add the initial gradients for the ys. 564 for y, grad_y in zip(ys, grad_ys): 565 _SetGrad(grads, y, grad_y) 566 567 # Initialize queue with to_ops. 568 queue = collections.deque() 569 # Add the ops in 'to_ops' into the queue. 570 to_ops_set = set() 571 for op in to_ops: 572 # 'ready' handles the case where one output gradient relies on 573 # another output's gradient. 574 ready = (pending_count[op] == 0) 575 if ready and op not in to_ops_set and op in reachable_to_ops: 576 to_ops_set.add(op) 577 queue.append(op) 578 579 if loop_state: 580 loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set) 581 for y in loop_exits: 582 if backprop_util.IsTrainable(y): 583 _SetGrad(grads, y, loop_state.ZerosLikeForExit(y)) 584 queue.append(y.op) 585 586 stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs_set) 587 while queue: 588 # generate gradient subgraph for op. 589 op = queue.popleft() 590 with _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops): 591 if loop_state: 592 loop_state.EnterGradWhileContext(op, before=True) 593 out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state, 594 aggregation_method) 595 if loop_state: 596 loop_state.ExitGradWhileContext(op, before=True) 597 598 grad_fn = None 599 func_call = None 600 is_partitioned_call = _IsPartitionedCall(op) 601 # pylint: disable=protected-access 602 is_func_call = ( 603 src_graph._is_function(op.type) or is_partitioned_call) 604 # pylint: enable=protected-access 605 has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads) 606 if has_out_grads and (op not in stop_ops): 607 try: 608 grad_fn = ops.get_gradient_function(op) 609 except LookupError: 610 if is_func_call: 611 if is_partitioned_call: 612 func_name = compat.as_bytes(op.get_attr("f").name) 613 func_call = src_graph._get_function( # pylint: disable=protected-access 614 func_name) 615 # When a graph is imported, the FunctionDefs are not copied over 616 # to each sub-graph so we recursively search the outer graphs 617 # for the FunctionDef. 618 if not func_call and hasattr(src_graph, "outer_graph"): 619 graph = src_graph.outer_graph 620 while graph is not None: 621 func_call = graph._get_function(func_name) # pylint: disable=protected-access 622 if func_call is not None: 623 break 624 if hasattr(graph, "outer_graph"): 625 graph = graph.outer_graph 626 else: 627 break 628 else: 629 func_call = src_graph._get_function(op.type) # pylint: disable=protected-access 630 # Note that __defun is not set if the graph is 631 # imported. If it's set, we prefer to access the original 632 # defun. 633 func_call = getattr(op, "__defun", func_call) 634 grad_fn = func_call.python_grad_func 635 else: 636 raise LookupError( 637 "No gradient defined for operation '%s' (op type: %s)" % 638 (op.name, op.type)) 639 if loop_state: 640 loop_state.EnterGradWhileContext(op, before=False) 641 642 # NOTE(skyewm): We don't support computing gradients wrt a loop variable 643 # unless it's within the context of a single iteration (i.e. the 644 # gradient is wrt to the loop parameter in the body function, not wrt or 645 # through the initial value). This means if we're in a while loop 646 # context, we should never see a switch node from this context. 647 # pylint: disable=protected-access 648 if (control_flow_util.IsSwitch(op) and 649 op._control_flow_context is not None and 650 op._control_flow_context.IsWhileContext() and 651 op._control_flow_context == 652 ops.get_default_graph()._get_control_flow_context()): 653 _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs_set) 654 # pylint: enable=protected-access 655 656 if (grad_fn or is_func_call) and has_out_grads: 657 # NOTE: If _AggregatedGrads didn't compute a value for the i'th 658 # output, it means that the cost does not depend on output[i], 659 # therefore dC/doutput[i] is 0. 660 for i, out_grad in enumerate(out_grads): 661 if (not isinstance(out_grad, ops.Tensor) and not out_grad) and ( 662 (not grad_fn and is_func_call) 663 or backprop_util.IsTrainable(op.outputs[i])): 664 # Only trainable outputs or outputs for a function call that 665 # will use SymbolicGradient get a zero gradient. Gradient 666 # functions should ignore the gradient for other outputs. 667 # TODO(apassos) gradients of resource handles might be an 668 # issue here because of zeros. 669 if loop_state: 670 out_grads[i] = loop_state.ZerosLikeV1WhileLoop(op, i) 671 elif default_gradient.supports_default_grad(op.outputs[i]): 672 # TODO(b/143286622): The supports_default_grad check is needed 673 # because While op emits non-differentiable resource tensors 674 # as outputs. Remove this check when that is not the case. 675 out_grads[i] = control_flow_state.ZerosLike(op, i) 676 with ops.name_scope(op.name + "_grad"): 677 # pylint: disable=protected-access 678 with src_graph._original_op(op): 679 # pylint: enable=protected-access 680 if grad_fn: 681 # If grad_fn was found, do not use SymbolicGradient even for 682 # functions. 683 in_grads = _MaybeCompile(grad_scope, op, func_call, 684 lambda: grad_fn(op, *out_grads)) 685 else: 686 # For function call ops, we add a 'SymbolicGradient' 687 # node to the graph to compute gradients. 688 in_grads = _MaybeCompile(grad_scope, op, func_call, 689 lambda: _SymGrad(op, out_grads)) 690 in_grads = _AsList(in_grads) 691 _VerifyGeneratedGradients(in_grads, op) 692 if gate_gradients and len([x for x in in_grads 693 if x is not None]) > 1: 694 with ops.device(None): 695 with ops._colocate_with_for_gradient( # pylint: disable=protected-access 696 None, 697 gradient_uid, 698 ignore_existing=True): 699 in_grads = control_flow_ops.tuple(in_grads) 700 _LogOpGradients(op, out_grads, in_grads) 701 else: 702 # If no grad_fn is defined or none of out_grads is available, 703 # just propagate a list of None backwards. 704 in_grads = [None] * len(_Inputs(op, xs_set)) 705 # Note: we don't filter out eager inputs here because the inputs need to 706 # line up with in_grads. 707 for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs_set), in_grads)): 708 if in_grad is not None: 709 if (isinstance(in_grad, ops.Tensor) and 710 t_in.dtype != dtypes.resource): 711 try: 712 in_grad.set_shape(t_in.get_shape()) 713 except ValueError: 714 raise ValueError( 715 "Incompatible shapes between op input and calculated " 716 "input gradient. Forward operation: %s. Input index: %d. " 717 "Original input shape: %s. " 718 "Calculated input gradient shape: %s" % 719 (op.name, i, t_in.shape, in_grad.shape)) 720 if not isinstance(t_in, ops.EagerTensor): 721 _SetGrad(grads, t_in, in_grad) 722 if loop_state: 723 loop_state.ExitGradWhileContext(op, before=False) 724 725 # Update pending count for the inputs of op and enqueue ready ops. 726 _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state, 727 xs_set) 728 729 if loop_state: 730 loop_state.PostProcessing() 731 return [_GetGrad(grads, x, unconnected_gradients) for x in xs] 732 733 734def _HasAnyNotNoneGrads(grads, op): 735 """Return true iff op has real gradient.""" 736 out_grads = _GetGrads(grads, op) 737 for out_grad in out_grads: 738 if isinstance(out_grad, (ops.Tensor, ops.IndexedSlices)): 739 return True 740 if out_grad and isinstance(out_grad, collections_abc.Sequence): 741 if any(g is not None for g in out_grad): 742 return True 743 return False 744 745 746def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state, 747 xs_set): 748 """Update pending count for the inputs of op and enqueue ready ops.""" 749 for x in _NonEagerInputs(op, xs_set): 750 pending_count[x.op] -= 1 751 ready = (pending_count[x.op] == 0) 752 if loop_state and not ready: 753 ready = pending_count[x.op] > 0 and control_flow_util.IsLoopSwitch(x.op) 754 if ready: 755 if control_flow_util.IsLoopExit(x.op): 756 # if x is an exit without real gradient, defer processing them. 757 grad_state = loop_state.GetGradState(x.op, before=False) 758 grad_state.deferred_exits.append(x) 759 grad_state.pending_exits_count -= 1 760 if grad_state.pending_exits_count == 0: 761 # We now have all the exits so process them. 762 has_not_none_grad = False 763 for y in grad_state.deferred_exits: 764 if _HasAnyNotNoneGrads(grads, y.op): 765 has_not_none_grad = True 766 queue.append(y.op) 767 else: 768 grad_state.unused_exits.append(y) 769 if has_not_none_grad: 770 # For an unused exit, if it has trainable outputs, backprop 771 # a zero gradient. Otherwise, just ignore it. 772 for y in grad_state.unused_exits: 773 if backprop_util.IsTrainable(y): 774 _SetGrad(grads, y, loop_state.ZerosLikeForExit(y)) 775 queue.append(y.op) 776 else: 777 # All exits are "unused" so use None as gradient. 778 for y in grad_state.unused_exits: 779 queue.append(y.op) 780 else: 781 queue.append(x.op) 782 783 784def _SetGrad(grads, t, grad): 785 """Sets gradient "grad" in "grads" for tensor "t".""" 786 op = t.op 787 op_grads = grads.get(op) 788 if not op_grads: 789 op_grads = [[] for _ in xrange(len(op.outputs))] 790 grads[op] = op_grads 791 t_grads = op_grads[t.value_index] 792 if isinstance(t_grads, list): 793 t_grads.append(grad) 794 else: 795 assert control_flow_util.IsLoopSwitch(op) 796 op_grads[t.value_index] = grad 797 798 799def _ZerosLike(t): 800 t_dtype = default_gradient.get_zeros_dtype(t) 801 if t.dtype == dtypes.resource: 802 return array_ops.zeros( 803 resource_variable_ops.variable_shape(t), dtype=t_dtype) 804 else: 805 return array_ops.zeros_like(t, dtype=t_dtype) 806 807 808def _GetGrad(grads, t, unconnected_gradients): 809 """Gets gradient for tensor "t".""" 810 op = t.op 811 op_grads = grads.get(op) 812 if not op_grads: 813 if unconnected_gradients == UnconnectedGradients.ZERO: 814 return _ZerosLike(t) 815 elif unconnected_gradients == UnconnectedGradients.NONE: 816 return None 817 else: 818 raise ValueError( 819 "Unknown value for unconnected_gradients: %r" % unconnected_gradients) 820 821 t_grad = op_grads[t.value_index] 822 # This can happen if some other output of `t.op` has non-None grad. 823 if unconnected_gradients == UnconnectedGradients.ZERO and t_grad is None: 824 return _ZerosLike(t) 825 826 assert not isinstance( 827 t_grad, list), ("gradients list should have been aggregated by now.") 828 return t_grad 829 830 831def _GetGrads(grads, op): 832 """Gets all gradients for op.""" 833 if op in grads: 834 return grads[op] 835 else: 836 return [[] for _ in xrange(len(op.outputs))] 837 838 839def _AccumulatorShape(inputs): 840 shape = tensor_shape.unknown_shape() 841 for i in inputs: 842 if isinstance(i, ops.Tensor): 843 shape = shape.merge_with(i.get_shape()) 844 return shape 845 846 847def _LogOpGradients(op, out_grads, in_grads): 848 """Log the in and out grads of an op.""" 849 logging.vlog(1, "Gradient for '" + op.name + "'") 850 851 def _FilterGrad(x): 852 if x is None: 853 return False 854 if isinstance(x, (list, tuple)): 855 return bool(x) 856 else: 857 return True 858 859 logging.vlog(1, " in --> %s", 860 ", ".join(x.name for x in out_grads if _FilterGrad(x))) 861 logging.vlog(1, " out --> %s", 862 ", ".join(x.name for x in in_grads if _FilterGrad(x))) 863 864 865def _MultiDeviceAddN(tensor_list, gradient_uid): 866 """Adds tensors from potentially multiple devices.""" 867 # Basic function structure comes from control_flow_ops.group(). 868 # Sort tensors according to their devices. 869 tensors_on_device = collections.defaultdict(lambda: []) 870 for tensor in tensor_list: 871 tensors_on_device[tensor.device].append(tensor) 872 873 # For each device, add the tensors on that device first. 874 # Then gather the partial sums from multiple devices. 875 # TODO(sjhwang): Create hierarchical aggregation tree as pbar's suggestion. 876 # E.g., aggregate per GPU, then per task, and so on. 877 summands = [] 878 879 def DeviceKey(dev): 880 return "" if dev is None else dev 881 882 for dev in sorted(tensors_on_device, key=DeviceKey): 883 tensors = tensors_on_device[dev] 884 with ops._colocate_with_for_gradient( # pylint: disable=protected-access 885 tensors[0].op, 886 gradient_uid, 887 ignore_existing=True): 888 summands.append(math_ops.add_n(tensors)) 889 890 return math_ops.add_n(summands) 891 892 893@tf_export("AggregationMethod") 894class AggregationMethod(object): 895 """A class listing aggregation methods used to combine gradients. 896 897 Computing partial derivatives can require aggregating gradient 898 contributions. This class lists the various methods that can 899 be used to combine gradients in the graph. 900 901 The following aggregation methods are part of the stable API for 902 aggregating gradients: 903 904 * `ADD_N`: All of the gradient terms are summed as part of one 905 operation using the "AddN" op (see `tf.add_n`). This 906 method has the property that all gradients must be ready and 907 buffered separately in memory before any aggregation is performed. 908 * `DEFAULT`: The system-chosen default aggregation method. 909 910 The following aggregation methods are experimental and may not 911 be supported in future releases: 912 913 * `EXPERIMENTAL_TREE`: Gradient terms are summed in pairs using 914 the "AddN" op. This method of summing gradients may reduce 915 performance, but it can improve memory utilization because the 916 gradients can be released earlier. 917 918 """ 919 ADD_N = 0 920 DEFAULT = ADD_N 921 # The following are experimental and may not be supported in future releases. 922 EXPERIMENTAL_TREE = 1 923 EXPERIMENTAL_ACCUMULATE_N = 2 # An alias for EXPERIMENTAL_ADD_N = 1 924 925 926def _AggregatedGrads(grads, 927 op, 928 gradient_uid, 929 loop_state, 930 aggregation_method=None): 931 """Get the aggregated gradients for op. 932 933 Args: 934 grads: The map of memoized gradients. 935 op: The op to get gradients for. 936 gradient_uid: A unique identifier within the graph indicating 937 which invocation of gradients is being executed. Used to cluster 938 ops for compilation. 939 loop_state: An object for maintaining the state of the while loops in the 940 graph. It is of type ControlFlowState. None if the graph 941 contains no while loops. 942 aggregation_method: Specifies the method used to combine gradient terms. 943 Accepted values are constants defined in the class `AggregationMethod`. 944 945 Returns: 946 A list of gradients, one per each output of `op`. If the gradients 947 for a particular output is a list, this function aggregates it 948 before returning. 949 950 Raises: 951 TypeError: if the incoming grads are not Tensors or IndexedSlices. 952 ValueError: if the arguments are invalid. 953 954 """ 955 if aggregation_method is None: 956 aggregation_method = AggregationMethod.DEFAULT 957 if aggregation_method not in [ 958 AggregationMethod.ADD_N, AggregationMethod.EXPERIMENTAL_TREE, 959 AggregationMethod.EXPERIMENTAL_ACCUMULATE_N 960 ]: 961 raise ValueError( 962 "Invalid aggregation_method specified %s." % aggregation_method) 963 out_grads = _GetGrads(grads, op) 964 for i, out_grad in enumerate(out_grads): 965 if loop_state: 966 if isinstance(out_grad, (ops.Tensor, ops.IndexedSlices)): 967 assert control_flow_util.IsLoopSwitch(op) 968 continue 969 # Grads have to be Tensors or IndexedSlices 970 if (isinstance(out_grad, collections_abc.Sequence) and not all( 971 isinstance(g, (ops.Tensor, ops.IndexedSlices)) 972 for g in out_grad 973 if g is not None)): 974 raise TypeError("gradients have to be either all Tensors " 975 "or all IndexedSlices") 976 # Aggregate multiple gradients, and convert [] to None. 977 if out_grad: 978 if len(out_grad) < 2: 979 used = "nop" 980 out_grads[i] = out_grad[0] 981 elif all(isinstance(g, ops.Tensor) for g in out_grad if g is not None): 982 tensor_shape = _AccumulatorShape(out_grad) 983 if aggregation_method in [ 984 AggregationMethod.EXPERIMENTAL_TREE, 985 AggregationMethod.EXPERIMENTAL_ACCUMULATE_N 986 ]: 987 # Aggregate all gradients by doing pairwise sums: this may 988 # reduce performance, but it can improve memory because the 989 # gradients can be released earlier. 990 # 991 # TODO(vrv): Consider replacing this with a version of 992 # tf.AddN() that eagerly frees its inputs as soon as they are 993 # ready, so the order of this tree does not become a problem. 994 used = "tree" 995 with ops.name_scope(op.name + "_gradient_sum"): 996 running_sum = out_grad[0] 997 for grad in out_grad[1:]: 998 running_sum = math_ops.add_n([running_sum, grad]) 999 out_grads[i] = running_sum 1000 else: 1001 used = "add_n" 1002 out_grads[i] = _MultiDeviceAddN(out_grad, gradient_uid) 1003 logging.vlog(2, " _AggregatedGrads %d x %s using %s", len(out_grad), 1004 tensor_shape, used) 1005 else: 1006 out_grads[i] = backprop.aggregate_indexed_slices_gradients(out_grad) # pylint: disable=protected-access 1007 else: # not out_grad 1008 # out_grads[i] is [], thus its aggregation is simply None. 1009 out_grads[i] = None 1010 return out_grads 1011 1012 1013# Represents the output of TFE_Py_TapeSetPossibleGradientTypes. Real enums are 1014# unfortunately too slow to use here. 1015POSSIBLE_GRADIENT_TYPES_NONE = 0 1016POSSIBLE_GRADIENT_TYPES_FIRST_ORDER = 1 1017POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER = 2 1018 1019 1020def PossibleTapeGradientTypes(tensors): 1021 """Determines whether and how `args` may require tape gradients.""" 1022 return pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes(tensors) 1023