# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements the graph generation for computation of gradients."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import contextlib
import warnings

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function as framework_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework.func_graph import FuncGraph
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export


# Warn the user if we convert a sparse representation to dense with at
# least this number of elements.
_LARGE_SPARSE_NUM_ELEMENTS = 100000000


def _IndexedSlicesToTensor(value, dtype=None, name=None, as_ref=False):
  """Converts an IndexedSlices object `value` to a Tensor.

  NOTE(mrry): This function is potentially expensive.

  Args:
    value: An ops.IndexedSlices object.
    dtype: The dtype of the Tensor to be returned.
    name: Optional name to use for the returned Tensor.
    as_ref: True if a ref is requested.

  Returns:
    A dense Tensor representing the values in the given IndexedSlices.

  Raises:
    ValueError: If the IndexedSlices does not have the same dtype.
  """
  _ = as_ref
  if dtype and not dtype.is_compatible_with(value.dtype):
    raise ValueError(
        "Tensor conversion requested dtype %s for IndexedSlices with dtype %s" %
        (dtype.name, value.dtype.name))
  if value.dense_shape is None:
    raise ValueError(
        "Tensor conversion requested for IndexedSlices without dense_shape: %s"
        % str(value))
  # TODO(mrry): Consider adding static shape information to
  # IndexedSlices, to avoid using numpy here.
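  # The warning below is only emitted while building a graph, where the dense
  # shape may be known statically via constant folding; when executing eagerly
  # the conversion proceeds without the size check.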
  if not context.executing_eagerly():
    dense_shape_value = tensor_util.constant_value(value.dense_shape)
    if dense_shape_value is not None:
      num_elements = np.prod(dense_shape_value)
      if num_elements >= _LARGE_SPARSE_NUM_ELEMENTS:
        warnings.warn(
            "Converting sparse IndexedSlices to a dense Tensor with %d "
            "elements. This may consume a large amount of memory." %
            num_elements)
    else:
      warnings.warn(
          "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
          "This may consume a large amount of memory.")
  return math_ops.unsorted_segment_sum(
      value.values, value.indices, value.dense_shape[0], name=name)


ops.register_tensor_conversion_function(ops.IndexedSlices,
                                        _IndexedSlicesToTensor)


def _MarkReachedOps(from_ops, reached_ops, func_graphs):
  """Mark all ops reached from "from_ops".

  Args:
    from_ops: list of Operations.
    reached_ops: set of Operations.
    func_graphs: list of FuncGraphs. This method will traverse through
      these functions if they capture from_ops or any reachable ops.
  """
  queue = collections.deque()
  queue.extend(from_ops)
  while queue:
    op = queue.popleft()
    if op not in reached_ops:
      reached_ops.add(op)
      for output in op.outputs:
        if _IsBackpropagatable(output):
          queue.extend(_Consumers(output, func_graphs))


def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
                  xs):
  """Initialize the pending count for ops between two lists of Operations.

  'pending_count[op]' indicates the number of backprop inputs
  to this operation.

  Args:
    to_ops: list of Operations.
    from_ops: list of Operations.
    colocate_gradients_with_ops: Python bool. See docstring of gradients().
    func_graphs: list of FuncGraphs. This method will traverse through
      these functions if they capture from_ops or any reachable ops. This is
      useful if to_ops occur in a function and from_ops are in an outer function
      or graph.
    xs: list of Tensors.

  Returns:
    A tuple containing: (1) the subset of to_ops reachable from from_ops by a
    path of zero or more backpropagatable tensors, (2) a mapping from operation
    to the number of backprop inputs to that op, and (3) a ControlFlowState
    object which is not None if the ops between from_ops and to_ops contain
    control flow loops.
  """
  # Mark reachable ops from from_ops.
  reached_ops = set()
  _MarkReachedOps(from_ops, reached_ops, func_graphs)
  # X in reached_ops iff X is reachable from from_ops by a path of zero or more
  # backpropagatable tensors.

  reachable_to_ops = set(op for op in to_ops if op in reached_ops)

  # Mark between ops.
  between_ops = set()
  between_op_list = []
  queue = collections.deque()
  queue.extend(to_ops)
  while queue:
    op = queue.popleft()
    # We are interested in this op.
    if op in reached_ops:
      between_ops.add(op)
      between_op_list.append(op)
      # Clear the boolean so we won't add the inputs again.
      reached_ops.remove(op)
      for inp in _NonEagerInputs(op, xs):
        queue.append(inp.op)
  # X in between_ops iff X is on a path of zero or more backpropagatable tensors
  # between from_ops and to_ops

  # 'loop_state' is None if there are no while loops.
  loop_state = control_flow_ops.MaybeCreateControlFlowState(
      between_op_list, between_ops, colocate_gradients_with_ops)

  # Initialize pending count for between ops.
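  # pending_count[op] records how many times op's outputs are consumed by
  # other ops in between_ops, i.e. how many gradient contributions must flow
  # back before op itself can be differentiated.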
  pending_count = collections.defaultdict(int)
  for op in between_op_list:
    for x in _NonEagerInputs(op, xs):
      if x.op in between_ops:
        pending_count[x.op] += 1

  return reachable_to_ops, pending_count, loop_state


def _AsList(x):
  return x if isinstance(x, (list, tuple)) else [x]


def _DefaultGradYs(grad_ys,
                   ys,
                   colocate_gradients_with_ops,
                   gradient_uid="__unsupported__"):
  """Fill in default values for grad_ys.

  Args:
    grad_ys: List of gradients, can contain None.
    ys: List of tensors.
    colocate_gradients_with_ops: If True, try colocating gradients with
      the corresponding op.
    gradient_uid: A unique identifier within the graph indicating
      which invocation of gradients is being executed. Used to cluster
      ops for compilation.

  Returns:
    A list of gradients to use, without None.

  Raises:
    ValueError: If sizes of gradients and inputs don't match.
    TypeError: If type of any gradient is not valid for its input.
  """
  if len(grad_ys) != len(ys):
    raise ValueError("Passed %d grad_ys for %d ys" % (len(grad_ys), len(ys)))
  grad_ys = ops.convert_n_to_tensor_or_indexed_slices(grad_ys, name="grad_y")
  new_grad_ys = []
  for i in xrange(len(grad_ys)):
    grad_y = grad_ys[i]
    y = ys[i]
    with _maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops):
      if grad_y is None:
        if y.dtype.is_complex:
          raise TypeError(
              "Gradients of complex tensors must set grad_ys (y.dtype = %r)" %
              y.dtype)
        new_grad_ys.append(
            array_ops.fill(
                array_ops.shape(y),
                constant_op.constant(1, dtype=y.dtype, name="grad_ys_%d" % i)))
        continue
      if y.dtype.is_floating or y.dtype.is_integer:
        if not grad_y.dtype.is_floating and not grad_y.dtype.is_integer:
          raise TypeError(
              "Gradient type %s generated for real or "
              "integer-valued tensor %s with type %s must be "
              "real or integer" % (dtypes.as_dtype(grad_y.dtype).name, y,
                                   dtypes.as_dtype(y.dtype).name))
      elif y.dtype.is_complex:
        if not grad_y.dtype.is_complex:
          raise TypeError(
              "Gradient type %s generated for complex-valued "
              "tensor %s with type %s must be complex" % (dtypes.as_dtype(
                  grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name))
      elif y.dtype == dtypes.variant:
        if grad_y.dtype != dtypes.variant:
          raise TypeError(
              "Gradient type %s generated for variant "
              "tensor %s with type %s must be variant" % (dtypes.as_dtype(
                  grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name))
      elif y.dtype == dtypes.resource:
        # We assume y is the handle of a ResourceVariable. The gradient of a
        # ResourceVariable should be a numeric value, not another resource.
        if grad_y.dtype == dtypes.resource:
          raise TypeError("Input gradient %s for resource tensor %s should not "
                          "be a resource" % (grad_y, y))
      else:
        raise TypeError(
            "Tensor %s with type %s must be numeric "
            "to obtain a default gradient" % (y, dtypes.as_dtype(y.dtype).name))
      # Create a grad_y tensor in the name scope of the gradient.
      # Required for TensorArrays to identify which gradient call a
      # grad_y value is coming from.
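      # For IndexedSlices the indices, values, and dense_shape components each
      # get their own identity op below, so that all three are created inside
      # this gradient call's name scope.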
      if isinstance(grad_y, ops.IndexedSlices):
        new_grad_ys.append(
            ops.IndexedSlices(
                indices=(array_ops.identity(
                    grad_y.indices, name="grad_ys_%d_indices" % i)
                         if isinstance(grad_y.indices, ops.Tensor) else
                         grad_y.indices),
                values=(array_ops.identity(
                    grad_y.values, name="grad_ys_%d_values" % i) if isinstance(
                        grad_y.values, ops.Tensor) else grad_y.values),
                dense_shape=(array_ops.identity(
                    grad_y.dense_shape, name="grad_ys_%d_shape" % i)
                             if isinstance(grad_y.dense_shape, ops.Tensor) else
                             grad_y.dense_shape)))
      else:
        new_grad_ys.append(array_ops.identity(grad_y, name="grad_ys_%d" % i))

  return new_grad_ys


def IsTrainable(tensor_or_dtype):
  if isinstance(tensor_or_dtype, ops.Tensor):
    dtype = tensor_or_dtype.dtype
  else:
    dtype = tensor_or_dtype
  dtype = dtypes.as_dtype(dtype)
  return dtype.base_dtype in (dtypes.float16, dtypes.float32, dtypes.float64,
                              dtypes.complex64, dtypes.complex128,
                              dtypes.resource, dtypes.variant)


def _IsBackpropagatable(tensor):
  if IsTrainable(tensor):
    return True
  dtype = dtypes.as_dtype(tensor.dtype)
  return dtype.base_dtype == dtypes.bfloat16


def _VerifyGeneratedGradients(grads, op):
  """Verify that gradients are valid in number and type.

  Args:
    grads: List of generated gradients.
    op: Operation for which the gradients were generated.

  Raises:
    ValueError: if sizes of gradients and inputs don't match.
    TypeError: if type of any gradient is not valid for its input.
  """
  # While ops have inputs added to them during the gradient computation, so we
  # skip the below check. See while_v2 for details.
  if op.type == "While": return

  if len(grads) != len(op.inputs):
    raise ValueError("Num gradients %d generated for op %s do not match num "
                     "inputs %d" % (len(grads), op.node_def, len(op.inputs)))


def _StopOps(from_ops, stop_gradient_ops, pending_count, xs):
  """The set of ops that terminate the gradient computation.

  This computes the frontier of the forward graph *before* which backprop
  should stop. Operations in the returned set will not be differentiated.
  This set is defined as the subset of `from_ops` containing ops that have
  no predecessor in `from_ops`. `pending_count` is the result of
  `_PendingCount(xs, from_ops)`. An 'op' has predecessors in `from_ops`
  iff pending_count[op] > 0.

  In addition, none of `stop_gradient_ops` will be differentiated.

  Args:
    from_ops: list of Operations.
    stop_gradient_ops: list of Operations never to backprop through.
    pending_count: mapping from operation to number of backprop inputs.
    xs: list of Tensors.

  Returns:
    The set of operations.
  """
  stop_ops = set()
  for op in from_ops:
    is_stop_op = True
    for inp in _NonEagerInputs(op, xs):
      if pending_count[inp.op] > 0:
        is_stop_op = False
        break
    if is_stop_op:
      stop_ops.add(op)
  stop_ops.update(op for op in stop_gradient_ops)
  return stop_ops


@contextlib.contextmanager
def _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops):  # pylint: disable=invalid-name
  """Context to colocate with `op` if `colocate_gradients_with_ops`."""
  if colocate_gradients_with_ops:
    with ops._colocate_with_for_gradient(op, gradient_uid):  # pylint: disable=protected-access
      yield
  else:
    yield


def _IsPartitionedCall(op):
  return op.type == "PartitionedCall" or op.type == "StatefulPartitionedCall"


def _SymGrad(op, out_grads):
  """Backprop through a function call node op given its outputs' gradients."""
  f_in = [x for x in op.inputs] + out_grads
  f_types = [x.dtype for x in op.inputs]
  f = attr_value_pb2.NameAttrList()
  if _IsPartitionedCall(op):
    f.name = op.get_attr("f").name
  else:
    f.name = op.type
  for k in op.node_def.attr:
    f.attr[k].CopyFrom(op.node_def.attr[k])
  # TODO(apassos) use a better dtype here
  in_grads = functional_ops.symbolic_gradient(
      input=f_in,
      Tout=[x if x != dtypes.resource else dtypes.float32 for x in f_types],
      f=f)
  return in_grads


def _MaybeCompile(scope, op, func, grad_fn):
  """Compile the calculation in grad_fn if op was marked as compiled."""
  scope = scope.rstrip("/").replace("/", "_")
  if func is not None:
    xla_compile = func.definition.attr["_XlaCompile"].b
    xla_separate_compiled_gradients = func.definition.attr[
        "_XlaSeparateCompiledGradients"].b
    xla_scope = func.definition.attr["_XlaScope"].s.decode()
  else:
    try:
      xla_compile = op.get_attr("_XlaCompile")
      xla_separate_compiled_gradients = op.get_attr(
          "_XlaSeparateCompiledGradients")
      xla_scope = op.get_attr("_XlaScope").decode()
    except ValueError:
      return grad_fn()  # Exit early

  if not xla_compile:
    return grad_fn()  # Exit early

  # If the gradients are supposed to be compiled separately, we give them a
  # _XlaScope name that is based on the name_scope of the gradients. Otherwise
  # they just inherit the existing _XlaScope name, which lets them be merged
  # together with the non-gradient computation.
  if xla_separate_compiled_gradients:
    xla_grad_scope = "%s_grad_%s" % (xla_scope, scope)
  else:
    xla_grad_scope = xla_scope

  attrs = {
      "_XlaCompile": attr_value_pb2.AttrValue(b=xla_compile),
      "_XlaScope": attr_value_pb2.AttrValue(s=xla_grad_scope.encode())
  }
  with ops.get_default_graph()._attr_scope(attrs):  # pylint: disable=protected-access
    return grad_fn()


def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs):
  """Raises an error if we backprop through a loop var."""
  # Find the nearest 'to_op' reachable from 'op' to provide a more helpful error
  # message.
  target_op = None
  queue = collections.deque([op])
  visited = set()
  while queue:
    curr_op = queue.popleft()
    if curr_op in visited: continue
    visited.add(curr_op)
    if curr_op in from_ops:
      target_op = curr_op
      break
    queue.extend(t.op for t in _NonEagerInputs(curr_op, xs))
  assert target_op
  raise ValueError(
      "Cannot compute gradient inside while loop with respect to op '%s'. "
      "We do not support taking the gradient wrt or through the initial value "
      "of a loop variable. Gradients can be computed through loop invariants "
      "or wrt the input parameters to the loop body."
      % target_op.name)


def _IsFunction(graph):
  return (isinstance(graph, FuncGraph) or
          isinstance(graph, framework_function._FuncGraph))  # pylint: disable=protected-access


def _Captures(func_graph):
  if isinstance(func_graph, FuncGraph):
    return func_graph.captures
  else:
    assert isinstance(func_graph, framework_function._FuncGraph)  # pylint: disable=protected-access
    return func_graph._captured  # pylint: disable=protected-access


def _MaybeCaptured(t):
  """If t is a captured value placeholder, returns the original captured value.

  Args:
    t: Tensor

  Returns:
    A tensor, potentially from a different Graph/FuncGraph.
  """
  # pylint: disable=protected-access
  if (not isinstance(t, ops.EagerTensor) and
      _IsFunction(t.op.graph) and t.op.type == "Placeholder"):
    for input_t, placeholder_t in _Captures(t.op.graph).items():
      if t == placeholder_t:
        return _MaybeCaptured(input_t)
  # pylint: enable=protected-access
  return t


def _NonEagerInputs(op, xs):
  """Returns the inputs of op, crossing closure boundaries where necessary.

  Does not return any captured EagerTensors, i.e., the number of tensors
  returned may be less than the actual number of inputs.

  Args:
    op: Operation
    xs: list of Tensors we are differentiating w.r.t.

  Returns:
    A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
    is in a FuncGraph and has captured inputs.
  """
  return [t for t in _Inputs(op, xs) if not isinstance(t, ops.EagerTensor)]


# TODO(skyewm): plumbing xs through everywhere is ugly, consider making
# _GradientsHelper a class with xs as a member variable.
def _Inputs(op, xs):
  """Returns the inputs of op, crossing closure boundaries where necessary.

  Args:
    op: Operation
    xs: list of Tensors we are differentiating w.r.t.

  Returns:
    A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
    is in a FuncGraph and has captured inputs.
  """
  if _IsFunction(op.graph):  # pylint: disable=protected-access
    inputs = []
    for t in op.inputs:
      # If we're differentiating w.r.t. `t`, do not attempt to traverse through
      # it to a captured value. The algorithm needs to "see" `t` in this case,
      # even if it's a function input for a captured value, whereas usually we'd
      # like to traverse through these closures as if the captured value was the
      # direct input to op.
      if t not in xs:
        t = _MaybeCaptured(t)
      inputs.append(t)
    return inputs
  else:
    return op.inputs


def _Consumers(t, func_graphs):
  """Returns the consumers of t, crossing closure boundaries where necessary.

  Args:
    t: Tensor
    func_graphs: a list of FuncGraphs that may have captured t.

  Returns:
    A list of Operations. The operations will be from the current graph and/or
    func_graphs.
  """
  consumers = t.consumers()
  for func in func_graphs:
    for input_t, placeholder in _Captures(func).items():
      if input_t == t:
        consumers.extend(_Consumers(placeholder, func_graphs))
  return consumers


def _GradientsHelper(ys,
                     xs,
                     grad_ys=None,
                     name="gradients",
                     colocate_gradients_with_ops=False,
                     gate_gradients=False,
                     aggregation_method=None,
                     stop_gradients=None,
                     unconnected_gradients=UnconnectedGradients.NONE,
                     src_graph=None):
  """Implementation of gradients()."""
  if context.executing_eagerly():
    raise RuntimeError("tf.gradients is not supported when eager execution "
                       "is enabled. Use tf.GradientTape instead.")
  if src_graph is None:
    src_graph = ops.get_default_graph()
  try:
    unconnected_gradients = UnconnectedGradients(unconnected_gradients)
  except ValueError:
    raise ValueError(
        "Unknown value for unconnected_gradients: %r" % unconnected_gradients)

  # If src_graph is a _FuncGraph (i.e. a function body), gather it and all
  # ancestor graphs. This is necessary for correctly handling captured values.
  func_graphs = []
  curr_graph = src_graph
  while _IsFunction(curr_graph):
    func_graphs.append(curr_graph)
    if isinstance(curr_graph, FuncGraph):
      curr_graph = curr_graph.outer_graph
    else:
      assert isinstance(curr_graph, framework_function._FuncGraph)  # pylint: disable=protected-access
      curr_graph = curr_graph._outer_graph  # pylint: disable=protected-access

  ys = _AsList(ys)
  xs = _AsList(xs)
  stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)
  if grad_ys is None:
    grad_ys = [None] * len(ys)
  else:
    grad_ys = _AsList(grad_ys)

  with ops.name_scope(
      name, "gradients",
      list(ys) + list(xs) + list(stop_gradients) + list(grad_ys)) as grad_scope:
    # Get a uid for this call to gradients that can be used to help
    # cluster ops for compilation.
    gradient_uid = ops.get_default_graph().unique_name("uid")
    ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
    xs = [
        x.handle if resource_variable_ops.is_resource_variable(x) else x
        for x in xs
    ]
    xs = ops.internal_convert_n_to_tensor_or_indexed_slices(
        xs, name="x", as_ref=True)
    grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops,
                             gradient_uid)

    # The approach we take here is as follows: Create a list of all ops in the
    # subgraph between the ys and xs. Visit these ops in reverse order of ids
    # to ensure that when we visit an op the gradients w.r.t. its outputs have
    # been collected. Then aggregate these gradients if needed, call the op's
    # gradient function, and add the generated gradients to the gradients for
    # its input.

    # Initialize the pending count for ops in the connected subgraph from ys
    # to the xs.
    to_ops = [t.op for t in ys]
    from_ops = [t.op for t in xs]
    stop_gradient_ops = [t.op for t in stop_gradients]
    reachable_to_ops, pending_count, loop_state = _PendingCount(
        to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs)

    # Iterate over the collected ops.
    #
    # grads: op => list of gradients received on each output endpoint of the
    # op. The gradients for each endpoint are initially collected as a list.
    # When it is time to call the op's gradient function, for each endpoint we
    # aggregate the list of received gradients into an Add() Operation if there
    # is more than one.
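    # Illustration: an op with two outputs whose first output feeds two
    # downstream ops would eventually map to [[g_a, g_b], [g_c]] here; the two
    # entries of the first inner list are summed before the op's gradient
    # function is invoked.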
    grads = {}

    # Add the initial gradients for the ys.
    for y, grad_y in zip(ys, grad_ys):
      _SetGrad(grads, y, grad_y)

    # Initialize queue with to_ops.
    queue = collections.deque()
    # Add the ops in 'to_ops' into the queue.
    to_ops_set = set()
    for op in to_ops:
      # 'ready' handles the case where one output gradient relies on
      # another output's gradient.
      ready = (pending_count[op] == 0)
      if ready and op not in to_ops_set and op in reachable_to_ops:
        to_ops_set.add(op)
        queue.append(op)

    if loop_state:
      loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set)
      for y in loop_exits:
        if IsTrainable(y):
          _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
          queue.append(y.op)

    stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs)
    while queue:
      # generate gradient subgraph for op.
      op = queue.popleft()
      with _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops):
        if loop_state:
          loop_state.EnterGradWhileContext(op, before=True)
        out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state,
                                     aggregation_method)
        if loop_state:
          loop_state.ExitGradWhileContext(op, before=True)

        grad_fn = None
        func_call = None
        is_partitioned_call = _IsPartitionedCall(op)
        # pylint: disable=protected-access
        is_func_call = (
            src_graph._is_function(op.type) or is_partitioned_call)
        # pylint: enable=protected-access
        has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
        if has_out_grads and (op not in stop_ops):
          try:
            grad_fn = ops.get_gradient_function(op)
          except LookupError:
            if is_func_call:
              if is_partitioned_call:
                func_call = src_graph._get_function(  # pylint: disable=protected-access
                    compat.as_bytes(op.get_attr("f").name))
              else:
                func_call = src_graph._get_function(op.type)  # pylint: disable=protected-access
              # Note that __defun is not set if the graph is
              # imported. If it's set, we prefer to access the original
              # defun.
              func_call = getattr(op, "__defun", func_call)
              grad_fn = func_call.python_grad_func
            else:
              raise LookupError(
                  "No gradient defined for operation '%s' (op type: %s)" %
                  (op.name, op.type))
        if loop_state:
          loop_state.EnterGradWhileContext(op, before=False)

        # NOTE(skyewm): We don't support computing gradients wrt a loop variable
        # unless it's within the context of a single iteration (i.e. the
        # gradient is wrt the loop parameter in the body function, not wrt or
        # through the initial value). This means if we're in a while loop
        # context, we should never see a switch node from this context.
        # pylint: disable=protected-access
        if (control_flow_util.IsSwitch(op) and
            op._control_flow_context is not None and
            op._control_flow_context.IsWhileContext() and
            op._control_flow_context ==
            ops.get_default_graph()._get_control_flow_context()):
          _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs)
        # pylint: enable=protected-access

        if (grad_fn or is_func_call) and has_out_grads:
          # NOTE: If _AggregatedGrads didn't compute a value for the i'th
          # output, it means that the cost does not depend on output[i],
          # therefore dC/doutput[i] is 0.
          for i, out_grad in enumerate(out_grads):
            if (not isinstance(out_grad, ops.Tensor) and not out_grad) and (
                (not grad_fn and is_func_call) or IsTrainable(op.outputs[i])):
              # Only trainable outputs or outputs for a function call that
              # will use SymbolicGradient get a zero gradient. Gradient
              # functions should ignore the gradient for other outputs.
              # TODO(apassos) gradients of resource handles might be an
              # issue here because of zeros.
              if loop_state:
                out_grads[i] = loop_state.ZerosLike(op, i)
              else:
                out_grads[i] = control_flow_ops.ZerosLikeOutsideLoop(op, i)
          with ops.name_scope(op.name + "_grad"):
            # pylint: disable=protected-access
            with src_graph._original_op(op):
              # pylint: enable=protected-access
              if grad_fn:
                # If grad_fn was found, do not use SymbolicGradient even for
                # functions.
                in_grads = _MaybeCompile(grad_scope, op, func_call,
                                         lambda: grad_fn(op, *out_grads))
              else:
                # For function call ops, we add a 'SymbolicGradient'
                # node to the graph to compute gradients.
                in_grads = _MaybeCompile(grad_scope, op, func_call,
                                         lambda: _SymGrad(op, out_grads))
              in_grads = _AsList(in_grads)
              _VerifyGeneratedGradients(in_grads, op)
              if gate_gradients and len([x for x in in_grads
                                         if x is not None]) > 1:
                with ops.device(None):
                  with ops._colocate_with_for_gradient(  # pylint: disable=protected-access
                      None,
                      gradient_uid,
                      ignore_existing=True):
                    in_grads = control_flow_ops.tuple(in_grads)
          _LogOpGradients(op, out_grads, in_grads)
        else:
          # If no grad_fn is defined or none of out_grads is available,
          # just propagate a list of None backwards.
          in_grads = [None] * len(_Inputs(op, xs))
        # Note: we don't filter out eager inputs here because the inputs need to
        # line up with in_grads.
        for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs), in_grads)):
          if in_grad is not None:
            if (isinstance(in_grad, ops.Tensor) and
                t_in.dtype != dtypes.resource):
              try:
                in_grad.set_shape(t_in.get_shape())
              except ValueError:
                raise ValueError(
                    "Incompatible shapes between op input and calculated "
                    "input gradient. Forward operation: %s. Input index: %d. "
                    "Original input shape: %s. "
                    "Calculated input gradient shape: %s" %
                    (op.name, i, t_in.shape, in_grad.shape))
            if not isinstance(t_in, ops.EagerTensor):
              _SetGrad(grads, t_in, in_grad)
        if loop_state:
          loop_state.ExitGradWhileContext(op, before=False)

      # Update pending count for the inputs of op and enqueue ready ops.
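      # An input op becomes "ready" once gradients from all of its consumers
      # in the between-subgraph have been recorded in `grads`; while-loop exits
      # are handled specially so that all exits of a loop are processed
      # together.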
      _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
                                    xs)

    if loop_state:
      loop_state.PostProcessing()
  return [_GetGrad(grads, x, unconnected_gradients) for x in xs]


def _HasAnyNotNoneGrads(grads, op):
  """Return true iff op has real gradient."""
  out_grads = _GetGrads(grads, op)
  for out_grad in out_grads:
    if isinstance(out_grad, (ops.Tensor, ops.IndexedSlices)):
      return True
    if out_grad and isinstance(out_grad, collections.Sequence):
      if any(g is not None for g in out_grad):
        return True
  return False


def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
                                  xs):
  """Update pending count for the inputs of op and enqueue ready ops."""
  for x in _NonEagerInputs(op, xs):
    pending_count[x.op] -= 1
    ready = (pending_count[x.op] == 0)
    if loop_state and not ready:
      ready = pending_count[x.op] > 0 and control_flow_util.IsLoopSwitch(x.op)
    if ready:
      if control_flow_util.IsLoopExit(x.op):
        # if x is an exit without real gradient, defer processing it.
        grad_state = loop_state.GetGradState(x.op, before=False)
        grad_state.deferred_exits.append(x)
        grad_state.pending_exits_count -= 1
        if grad_state.pending_exits_count == 0:
          # We now have all the exits so process them.
          has_not_none_grad = False
          for y in grad_state.deferred_exits:
            if _HasAnyNotNoneGrads(grads, y.op):
              has_not_none_grad = True
              queue.append(y.op)
            else:
              grad_state.unused_exits.append(y)
          if has_not_none_grad:
            # For an unused exit, if it has trainable outputs, backprop
            # a zero gradient. Otherwise, just ignore it.
            for y in grad_state.unused_exits:
              if IsTrainable(y):
                _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
              queue.append(y.op)
          else:
            # All exits are "unused" so use None as gradient.
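            # The exit ops are enqueued without any recorded gradient, so None
            # is simply propagated backwards through them.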
            for y in grad_state.unused_exits:
              queue.append(y.op)
      else:
        queue.append(x.op)


def _SetGrad(grads, t, grad):
  """Sets gradient "grad" in "grads" for tensor "t"."""
  op = t.op
  op_grads = grads.get(op)
  if not op_grads:
    op_grads = [[] for _ in xrange(len(op.outputs))]
    grads[op] = op_grads
  t_grads = op_grads[t.value_index]
  if isinstance(t_grads, list):
    t_grads.append(grad)
  else:
    assert control_flow_util.IsLoopSwitch(op)
    op_grads[t.value_index] = grad


def _GetGrad(grads, t, unconnected_gradients):
  """Gets gradient for tensor "t"."""
  op = t.op
  op_grads = grads.get(op)
  if not op_grads:
    if unconnected_gradients == UnconnectedGradients.ZERO:
      t_dtype = t.dtype if t.dtype != dtypes.resource else dtypes.float32
      return array_ops.zeros_like(t, dtype=t_dtype)
    elif unconnected_gradients == UnconnectedGradients.NONE:
      return None
    else:
      raise ValueError(
          "Unknown value for unconnected_gradients: %r" % unconnected_gradients)

  t_grad = op_grads[t.value_index]
  assert not isinstance(
      t_grad, list), ("gradients list should have been aggregated by now.")
  return t_grad


def _GetGrads(grads, op):
  """Gets all gradients for op."""
  if op in grads:
    return grads[op]
  else:
    return [[] for _ in xrange(len(op.outputs))]


def _HandleNestedIndexedSlices(grad):
  assert isinstance(grad, ops.IndexedSlices)
  if isinstance(grad.values, ops.Tensor):
    return grad
  else:
    assert isinstance(grad.values, ops.IndexedSlices)
    g = _HandleNestedIndexedSlices(grad.values)
    return ops.IndexedSlices(g.values, array_ops.gather(
        grad.indices, g.indices), g.dense_shape)


def _AccumulatorShape(inputs):
  shape = tensor_shape.unknown_shape()
  for i in inputs:
    if isinstance(i, ops.Tensor):
      shape = shape.merge_with(i.get_shape())
  return shape


def _LogOpGradients(op, out_grads, in_grads):
  """Log the in and out grads of an op."""
  logging.vlog(1, "Gradient for '" + op.name + "'")

  def _FilterGrad(x):
    if x is None:
      return False
    if isinstance(x, (list, tuple)):
      return bool(x)
    else:
      return True

  logging.vlog(1, "  in  --> %s",
               ", ".join([x.name for x in out_grads if _FilterGrad(x)]))
  logging.vlog(1, "  out --> %s",
               ", ".join([x.name for x in in_grads if _FilterGrad(x)]))


def _MultiDeviceAddN(tensor_list, gradient_uid):
  """Adds tensors from potentially multiple devices."""
  # Basic function structure comes from control_flow_ops.group().
  # Sort tensors according to their devices.
  tensors_on_device = collections.defaultdict(lambda: [])
  for tensor in tensor_list:
    tensors_on_device[tensor.device].append(tensor)

  # For each device, add the tensors on that device first.
  # Then gather the partial sums from multiple devices.
  # TODO(sjhwang): Create hierarchical aggregation tree as pbar's suggestion.
  # E.g., aggregate per GPU, then per task, and so on.
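  # Summing per device first keeps each partial AddN on the device that
  # already holds its inputs, so only the per-device partial sums have to
  # cross devices in the final add_n below.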
  summands = []

  def DeviceKey(dev):
    return "" if dev is None else dev

  for dev in sorted(tensors_on_device, key=DeviceKey):
    tensors = tensors_on_device[dev]
    with ops._colocate_with_for_gradient(  # pylint: disable=protected-access
        tensors[0].op,
        gradient_uid,
        ignore_existing=True):
      summands.append(math_ops.add_n(tensors))

  return math_ops.add_n(summands)


@tf_export("AggregationMethod")
class AggregationMethod(object):
  """A class listing aggregation methods used to combine gradients.

  Computing partial derivatives can require aggregating gradient
  contributions. This class lists the various methods that can
  be used to combine gradients in the graph:

  * `ADD_N`: All of the gradient terms are summed as part of one
    operation using the "AddN" op. It has the property that all
    gradients must be ready before any aggregation is performed.
  * `DEFAULT`: The system-chosen default aggregation method.
  """
  ADD_N = 0
  DEFAULT = ADD_N
  # The following are experimental and may not be supported in future releases.
  EXPERIMENTAL_TREE = 1
  EXPERIMENTAL_ACCUMULATE_N = 2


def _AggregatedGrads(grads,
                     op,
                     gradient_uid,
                     loop_state,
                     aggregation_method=None):
  """Get the aggregated gradients for op.

  Args:
    grads: The map of memoized gradients.
    op: The op to get gradients for.
    gradient_uid: A unique identifier within the graph indicating
      which invocation of gradients is being executed. Used to cluster
      ops for compilation.
    loop_state: An object for maintaining the state of the while loops in the
      graph. It is of type ControlFlowState. None if the graph
      contains no while loops.
    aggregation_method: Specifies the method used to combine gradient terms.
      Accepted values are constants defined in the class `AggregationMethod`.

  Returns:
    A list of gradients, one for each output of `op`. If the gradient
    for a particular output is a list, this function aggregates it
    before returning.

  Raises:
    TypeError: if the incoming grads are not Tensors or IndexedSlices.
    ValueError: if the arguments are invalid.

  """
  if aggregation_method is None:
    aggregation_method = AggregationMethod.DEFAULT
  if aggregation_method not in [
      AggregationMethod.ADD_N, AggregationMethod.EXPERIMENTAL_TREE,
      AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
  ]:
    raise ValueError(
        "Invalid aggregation_method specified %s." % aggregation_method)
  out_grads = _GetGrads(grads, op)
  for i, out_grad in enumerate(out_grads):
    if loop_state:
      if isinstance(out_grad, (ops.Tensor, ops.IndexedSlices)):
        assert control_flow_util.IsLoopSwitch(op)
        continue
    # Grads have to be Tensors or IndexedSlices
    if (isinstance(out_grad, collections.Sequence) and not all(
        isinstance(g, (ops.Tensor, ops.IndexedSlices))
        for g in out_grad
        if g is not None
    )):
      raise TypeError("gradients have to be either all Tensors "
                      "or all IndexedSlices")
    # Aggregate multiple gradients, and convert [] to None.
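    # Aggregation strategy: a single gradient passes through unchanged; dense
    # gradients are combined with accumulate_n, a pairwise tree, or a
    # device-aware AddN depending on aggregation_method; IndexedSlices
    # gradients are concatenated.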
    if out_grad:
      if len(out_grad) < 2:
        used = "nop"
        out_grads[i] = out_grad[0]
      elif all(isinstance(g, ops.Tensor) for g in out_grad if g is not None):
        tensor_shape = _AccumulatorShape(out_grad)
        if (aggregation_method == AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
            and len(out_grad) > 2 and tensor_shape.is_fully_defined()):
          # The benefit of using AccumulateN is that its inputs can be combined
          # in any order and this can allow the expression to be evaluated with
          # a smaller memory footprint. When used with gpu_allocator_retry,
          # it is possible to compute a sum of terms which are much larger than
          # total GPU memory.
          # AccumulateN can currently only be used if we know the shape for
          # an accumulator variable. If this is not known, or if we only have
          # 2 grads then we fall through to the "tree" case below.
          used = "accumulate_n"
          out_grads[i] = math_ops.accumulate_n(out_grad)
        elif aggregation_method in [
            AggregationMethod.EXPERIMENTAL_TREE,
            AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
        ]:
          # Aggregate all gradients by doing pairwise sums: this may
          # reduce performance, but it can improve memory because the
          # gradients can be released earlier.
          #
          # TODO(vrv): Consider replacing this with a version of
          # tf.AddN() that eagerly frees its inputs as soon as they are
          # ready, so the order of this tree does not become a problem.
          used = "tree"
          with ops.name_scope(op.name + "_gradient_sum"):
            running_sum = out_grad[0]
            for grad in out_grad[1:]:
              running_sum = math_ops.add_n([running_sum, grad])
            out_grads[i] = running_sum
        else:
          used = "add_n"
          out_grads[i] = _MultiDeviceAddN(out_grad, gradient_uid)
        logging.vlog(2, "  _AggregatedGrads %d x %s using %s", len(out_grad),
                     tensor_shape, used)
      else:
        out_grads[i] = _AggregateIndexedSlicesGradients(out_grad)
    else:  # not out_grad
      # out_grads[i] is [], thus its aggregation is simply None.
      out_grads[i] = None
  return out_grads


def _AggregateIndexedSlicesGradients(grads):
  """Aggregates gradients of type `IndexedSlices` by concatenation."""
  if len(grads) < 1:
    return None
  elif len(grads) == 1:
    return grads[0]
  else:
    grads = math_ops._as_indexed_slices_list(  # pylint: disable=protected-access
        [g for g in grads if g is not None])
    grads = [_HandleNestedIndexedSlices(x) for x in grads]  # pylint: disable=protected-access
    # Form IndexedSlices out of the concatenated values and indices.
    concat_grad = ops.IndexedSlices(
        array_ops.concat([x.values for x in grads], axis=0),
        array_ops.concat([x.indices for x in grads], axis=0),
        grads[0].dense_shape)

    return concat_grad
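
# Usage sketch (illustrative comment only; nothing below is executed). The
# private _GradientsHelper above is normally reached through the public
# `tf.gradients` wrapper in a graph-mode (TF 1.x style) program, for example:
#
#   import tensorflow as tf
#
#   with tf.Graph().as_default():
#     x = tf.constant(3.0)
#     y = x * x
#     dy_dx, = tf.gradients(ys=[y], xs=[x])  # builds the gradient subgraph
#     with tf.Session() as sess:
#       print(sess.run(dy_dx))  # 6.0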