# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements the graph generation for computation of gradients."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import contextlib
import warnings

import numpy as np
import six
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_grad  # pylint: disable=unused-import
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops  # pylint: disable=unused-import
from tensorflow.python.ops import control_flow_grad  # pylint: disable=unused-import
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import image_grad  # pylint: disable=unused-import
from tensorflow.python.ops import linalg_grad  # pylint: disable=unused-import
from tensorflow.python.ops import linalg_ops  # pylint: disable=unused-import
from tensorflow.python.ops import logging_ops  # pylint: disable=unused-import
from tensorflow.python.ops import manip_grad  # pylint: disable=unused-import
from tensorflow.python.ops import math_grad  # pylint: disable=unused-import
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import spectral_grad  # pylint: disable=unused-import
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export

# Warn the user if we convert a sparse representation to dense with at
# least this number of elements.
_LARGE_SPARSE_NUM_ELEMENTS = 100000000


def _IndexedSlicesToTensor(value, dtype=None, name=None, as_ref=False):
  """Converts an IndexedSlices object `value` to a Tensor.

  NOTE(mrry): This function is potentially expensive.

  Args:
    value: An ops.IndexedSlices object.
    dtype: The dtype of the Tensor to be returned.
    name: Optional name to use for the returned Tensor.
    as_ref: True if a ref is requested.

  Returns:
    A dense Tensor representing the values in the given IndexedSlices.

  Raises:
    ValueError: If the IndexedSlices is incompatible with the requested
      `dtype`, or does not have a known `dense_shape`.
  """
  _ = as_ref
  if dtype and not dtype.is_compatible_with(value.dtype):
    raise ValueError(
        "Tensor conversion requested dtype %s for IndexedSlices with dtype %s" %
        (dtype.name, value.dtype.name))
  if value.dense_shape is None:
    raise ValueError(
        "Tensor conversion requested for IndexedSlices without dense_shape: %s"
        % str(value))
  # TODO(mrry): Consider adding static shape information to
  # IndexedSlices, to avoid using numpy here.
  dense_shape_value = tensor_util.constant_value(value.dense_shape)
  if dense_shape_value is not None:
    num_elements = np.prod(dense_shape_value)
    if num_elements >= _LARGE_SPARSE_NUM_ELEMENTS:
      warnings.warn(
          "Converting sparse IndexedSlices to a dense Tensor with %d elements. "
          "This may consume a large amount of memory." % num_elements)
  else:
    warnings.warn(
        "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
        "This may consume a large amount of memory.")
  return math_ops.unsorted_segment_sum(
      value.values, value.indices, value.dense_shape[0], name=name)


ops.register_tensor_conversion_function(ops.IndexedSlices,
                                        _IndexedSlicesToTensor)
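

# Example (illustrative sketch only, not executed as part of this module; the
# values and variable names are hypothetical): the registration above means an
# IndexedSlices, such as the gradient produced by a gather, is densified
# automatically whenever a plain Tensor is required.
#
#   params = constant_op.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
#   rows = constant_op.constant([0, 2])
#   y = array_ops.gather(params, rows)
#   grad = gradients(math_ops.reduce_sum(y), params)[0]
#   # 'grad' is an ops.IndexedSlices with indices [0, 2].
#   dense = ops.convert_to_tensor(grad)  # Invokes _IndexedSlicesToTensor.
#   # 'dense' evaluates to [[1., 1.], [0., 0.], [1., 1.]].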
77 """ 78 _ = as_ref 79 if dtype and not dtype.is_compatible_with(value.dtype): 80 raise ValueError( 81 "Tensor conversion requested dtype %s for IndexedSlices with dtype %s" % 82 (dtype.name, value.dtype.name)) 83 if value.dense_shape is None: 84 raise ValueError( 85 "Tensor conversion requested for IndexedSlices without dense_shape: %s" 86 % str(value)) 87 # TODO(mrry): Consider adding static shape information to 88 # IndexedSlices, to avoid using numpy here. 89 dense_shape_value = tensor_util.constant_value(value.dense_shape) 90 if dense_shape_value is not None: 91 num_elements = np.prod(dense_shape_value) 92 if num_elements >= _LARGE_SPARSE_NUM_ELEMENTS: 93 warnings.warn( 94 "Converting sparse IndexedSlices to a dense Tensor with %d elements. " 95 "This may consume a large amount of memory." % num_elements) 96 else: 97 warnings.warn( 98 "Converting sparse IndexedSlices to a dense Tensor of unknown shape. " 99 "This may consume a large amount of memory.") 100 return math_ops.unsorted_segment_sum( 101 value.values, value.indices, value.dense_shape[0], name=name) 102 103 104ops.register_tensor_conversion_function(ops.IndexedSlices, 105 _IndexedSlicesToTensor) 106 107 108def _MarkReachedOps(from_ops, reached_ops): 109 """Mark all ops reached from "from_ops". 110 111 Args: 112 from_ops: list of Operations. 113 reached_ops: list of booleans, indexed by operation id. 114 """ 115 queue = collections.deque() 116 queue.extend(from_ops) 117 while queue: 118 op = queue.popleft() 119 if not reached_ops[op._id]: 120 reached_ops[op._id] = True 121 for output in op.outputs: 122 queue.extend(output.consumers()) 123 124 125def _GatherInputs(to_ops, reached_ops): 126 """List all inputs of to_ops that are in reached_ops. 127 128 Args: 129 to_ops: list of Operations. 130 reached_ops: list of booleans, indexed by operation id. 131 132 Returns: 133 The list of all inputs of to_ops that are in reached_ops. 134 That list includes all elements of to_ops. 135 """ 136 inputs = [] 137 queue = collections.deque() 138 queue.extend(to_ops) 139 while queue: 140 op = queue.popleft() 141 # We are interested in this op. 142 if reached_ops[op._id]: 143 inputs.append(op) 144 # Clear the boolean so we won't add the inputs again. 145 reached_ops[op._id] = False 146 for inp in op.inputs: 147 queue.append(inp.op) 148 return inputs 149 150 151def _PendingCount(graph, to_ops, from_ops, colocate_gradients_with_ops): 152 """Initialize the pending count for ops between two lists of Operations. 153 154 'pending_count[op._id]' indicates the number of backprop inputs 155 to this operation. 156 157 Args: 158 graph: a Graph. 159 to_ops: list of Operations. 160 from_ops: list of Operations. 161 colocate_gradients_with_ops: Python bool. See docstring of gradients(). 162 163 Returns: 164 A tuple containing: (1) a list of integers indexed by operation id, 165 indicating the number of backprop inputs to this operation, and (2) 166 a ControlFlowState object which is not None if the ops between from_ops 167 and to_ops contain control flow loops. 168 """ 169 # Mark reachable ops from from_ops. 170 reached_ops = [False] * (graph._last_id + 1) 171 for op in to_ops: 172 reached_ops[op._id] = True 173 _MarkReachedOps(from_ops, reached_ops) 174 175 # Mark between ops. 176 between_ops = [False] * (graph._last_id + 1) 177 between_op_list = [] 178 queue = collections.deque() 179 queue.extend(to_ops) 180 while queue: 181 op = queue.popleft() 182 # We are interested in this op. 
    if reached_ops[op._id]:
      between_ops[op._id] = True
      between_op_list.append(op)
      # Clear the boolean so we won't add the inputs again.
      reached_ops[op._id] = False
      for inp in op.inputs:
        queue.append(inp.op)

  # 'loop_state' is None if there are no while loops.
  loop_state = control_flow_ops.MaybeCreateControlFlowState(
      between_op_list, between_ops, colocate_gradients_with_ops)

  # Initialize pending count for between ops.
  pending_count = [0] * (graph._last_id + 1)
  for op in between_op_list:
    for x in op.inputs:
      if between_ops[x.op._id]:
        pending_count[x.op._id] += 1

  return pending_count, loop_state


def _AsList(x):
  return x if isinstance(x, (list, tuple)) else [x]


def _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops):
  """Fill in default values for grad_ys.

  Args:
    grad_ys: List of gradients, can contain None.
    ys: List of tensors.
    colocate_gradients_with_ops: If True, try colocating gradients with
      the corresponding op.

  Returns:
    A list of gradients to use, without None.

  Raises:
    ValueError: If sizes of gradients and inputs don't match.
    TypeError: If type of any gradient is not valid for its input.
  """
  if len(grad_ys) != len(ys):
    raise ValueError("Passed %d grad_ys for %d ys" % (len(grad_ys), len(ys)))
  grad_ys = ops.convert_n_to_tensor_or_indexed_slices(grad_ys, name="grad_y")
  new_grad_ys = []
  for i in xrange(len(grad_ys)):
    grad_y = grad_ys[i]
    y = ys[i]
    with _maybe_colocate_with(y.op, colocate_gradients_with_ops):
      if grad_y is None:
        if y.dtype.is_complex:
          raise TypeError(
              "Gradients of complex tensors must set grad_ys (y.dtype = %r)" %
              y.dtype)
        new_grad_ys.append(
            array_ops.fill(
                array_ops.shape(y),
                constant_op.constant(1, dtype=y.dtype, name="grad_ys_%d" % i)))
        continue
      if y.dtype.is_floating or y.dtype.is_integer:
        if not grad_y.dtype.is_floating and not grad_y.dtype.is_integer:
          raise TypeError("Gradient type %s generated for real or "
                          "integer-valued tensor %s with type %s must be "
                          "real or integer" %
                          (dtypes.as_dtype(grad_y.dtype).name, y,
                           dtypes.as_dtype(y.dtype).name))
      elif y.dtype.is_complex:
        if not grad_y.dtype.is_complex:
          raise TypeError("Gradient type %s generated for complex-valued "
                          "tensor %s with type %s must be complex" %
                          (dtypes.as_dtype(grad_y.dtype).name, y,
                           dtypes.as_dtype(y.dtype).name))
      else:
        raise TypeError("Tensor %s with type %s must be numeric "
                        "to obtain a default gradient" %
                        (y, dtypes.as_dtype(y.dtype).name))
      # Create a grad_y tensor in the name scope of the gradient.
      # Required for TensorArrays to identify which gradient call a
      # grad_y value is coming from.
      if isinstance(grad_y, ops.IndexedSlices):
        new_grad_ys.append(
            ops.IndexedSlices(
                indices=(array_ops.identity(
                    grad_y.indices, name="grad_ys_%d_indices" % i)
                         if isinstance(grad_y.indices, ops.Tensor) else
                         grad_y.indices),
                values=(array_ops.identity(
                    grad_y.values, name="grad_ys_%d_values" % i) if isinstance(
                        grad_y.values, ops.Tensor) else grad_y.values),
                dense_shape=(array_ops.identity(
                    grad_y.dense_shape, name="grad_ys_%d_shape" % i)
                             if isinstance(grad_y.dense_shape, ops.Tensor) else
                             grad_y.dense_shape)))
      else:
        new_grad_ys.append(array_ops.identity(grad_y, name="grad_ys_%d" % i))

  return new_grad_ys


def _IsTrainable(tensor):
  dtype = dtypes.as_dtype(tensor.dtype)
  return dtype.base_dtype in (dtypes.float16, dtypes.float32, dtypes.float64,
                              dtypes.complex64, dtypes.complex128)


def _VerifyGeneratedGradients(grads, op):
  """Verify that gradients are valid in number and type.

  Args:
    grads: List of generated gradients.
    op: Operation for which the gradients were generated.

  Raises:
    ValueError: if sizes of gradients and inputs don't match.
    TypeError: if type of any gradient is not valid for its input.
  """
  if len(grads) != len(op.inputs):
    raise ValueError("Num gradients %d generated for op %s do not match num "
                     "inputs %d" % (len(grads), op.node_def, len(op.inputs)))


def _StopOps(from_ops, stop_gradient_ops, pending_count):
  """The set of ops that terminate the gradient computation.

  This computes the frontier of the forward graph *before* which backprop
  should stop. Operations in the returned set will not be differentiated.
  This set is defined as the subset of `from_ops` containing ops that have
  no predecessor in `from_ops`. `pending_count` is the result of
  `_PendingCount(g, to_ops, from_ops)`. An 'op' has predecessors in
  `from_ops` iff pending_count[op._id] > 0.

  In addition, none of `stop_gradient_ops` will be differentiated.

  Args:
    from_ops: list of Operations.
    stop_gradient_ops: list of Operations never to backprop through.
    pending_count: List of integers, indexed by operation id.

  Returns:
    The set of operation ids at which backprop should stop.
324 """ 325 stop_ops = set() 326 for op in from_ops: 327 is_stop_op = True 328 for inp in op.inputs: 329 if pending_count[inp.op._id] > 0: 330 is_stop_op = False 331 break 332 if is_stop_op: 333 stop_ops.add(op._id) 334 stop_ops.update(op._id for op in stop_gradient_ops) # pylint: disable=protected-access 335 return stop_ops 336 337 338@contextlib.contextmanager 339def _maybe_colocate_with(op, colocate_gradients_with_ops): 340 """Context to colocate with `op` if `colocate_gradients_with_ops`.""" 341 if colocate_gradients_with_ops: 342 with ops.colocate_with(op): 343 yield 344 else: 345 yield 346 347 348def _SymGrad(op, out_grads): 349 """Backprop through a function call node op given its outputs' gradients.""" 350 f_in = [x for x in op.inputs] + out_grads 351 f_types = [x.dtype for x in op.inputs] 352 f = attr_value_pb2.NameAttrList() 353 f.name = op.type 354 for k in op.node_def.attr: 355 f.attr[k].CopyFrom(op.node_def.attr[k]) 356 # pylint: disable=protected-access 357 in_grads = functional_ops._symbolic_gradient(input=f_in, Tout=f_types, f=f) 358 # pylint: enable=protected-access 359 return in_grads 360 361 362def _MaybeCompile(scope, op, func, grad_fn): 363 """Compile the calculation in grad_fn if op was marked as compiled.""" 364 scope = scope.rstrip("/").replace("/", "_") 365 if func is not None: 366 xla_compile = func.definition.attr["_XlaCompile"].b 367 xla_separate_compiled_gradients = func.definition.attr[ 368 "_XlaSeparateCompiledGradients"].b 369 xla_scope = func.definition.attr["_XlaScope"].s.decode() 370 else: 371 try: 372 xla_compile = op.get_attr("_XlaCompile") 373 xla_separate_compiled_gradients = op.get_attr( 374 "_XlaSeparateCompiledGradients") 375 xla_scope = op.get_attr("_XlaScope").decode() 376 except ValueError: 377 return grad_fn() # Exit early 378 379 if not xla_compile: 380 return grad_fn() # Exit early 381 382 # If the gradients are supposed to be compiled separately, we give them a 383 # _XlaScope name that is based on the name_scope of the gradients. Otherwise 384 # they just inherit the existing _XlaScope name, which lets them be merged 385 # together with the non-gradient computation. 386 if xla_separate_compiled_gradients: 387 xla_grad_scope = "%s_grad_%s" % (xla_scope, scope) 388 else: 389 xla_grad_scope = xla_scope 390 391 attrs = { 392 "_XlaCompile": attr_value_pb2.AttrValue(b=xla_compile), 393 "_XlaScope": attr_value_pb2.AttrValue(s=xla_grad_scope.encode()) 394 } 395 with ops.get_default_graph()._attr_scope(attrs): # pylint: disable=protected-access 396 return grad_fn() 397 398 399@tf_export("gradients") 400def gradients(ys, 401 xs, 402 grad_ys=None, 403 name="gradients", 404 colocate_gradients_with_ops=False, 405 gate_gradients=False, 406 aggregation_method=None, 407 stop_gradients=None): 408 """Constructs symbolic derivatives of sum of `ys` w.r.t. x in `xs`. 409 410 `ys` and `xs` are each a `Tensor` or a list of tensors. `grad_ys` 411 is a list of `Tensor`, holding the gradients received by the 412 `ys`. The list must be the same length as `ys`. 413 414 `gradients()` adds ops to the graph to output the derivatives of `ys` with 415 respect to `xs`. It returns a list of `Tensor` of length `len(xs)` where 416 each tensor is the `sum(dy/dx)` for y in `ys`. 417 418 `grad_ys` is a list of tensors of the same length as `ys` that holds 419 the initial gradients for each y in `ys`. When `grad_ys` is None, 420 we fill in a tensor of '1's of the shape of y for each y in `ys`. 
  user can provide their own initial `grad_ys` to compute the
  derivatives using a different initial gradient for each y (e.g., if
  one wanted to weight the gradient differently for each value in
  each y).

  `stop_gradients` is a `Tensor` or a list of tensors to be considered constant
  with respect to all `xs`. These tensors will not be backpropagated through,
  as though they had been explicitly disconnected using `stop_gradient`.  Among
  other things, this allows computation of partial derivatives as opposed to
  total derivatives. For example:

  ```python
  a = tf.constant(0.)
  b = 2 * a
  g = tf.gradients(a + b, [a, b], stop_gradients=[a, b])
  ```

  Here the partial derivatives `g` evaluate to `[1.0, 1.0]`, compared to the
  total derivatives `tf.gradients(a + b, [a, b])`, which take into account the
  influence of `a` on `b` and evaluate to `[3.0, 1.0]`. Note that the above is
  equivalent to:

  ```python
  a = tf.stop_gradient(tf.constant(0.))
  b = tf.stop_gradient(2 * a)
  g = tf.gradients(a + b, [a, b])
  ```

  `stop_gradients` provides a way of stopping gradient computation after the
  graph has already been constructed, as compared to `tf.stop_gradient` which
  is used during graph construction.  When the two approaches are combined,
  backpropagation stops at both `tf.stop_gradient` nodes and nodes in
  `stop_gradients`, whichever is encountered first.

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    grad_ys: Optional. A `Tensor` or list of tensors the same size as
      `ys` and holding the gradients computed for each y in `ys`.
    name: Optional name to use for grouping all the gradient ops together.
      Defaults to 'gradients'.
    colocate_gradients_with_ops: If True, try colocating gradients with
      the corresponding op.
    gate_gradients: If True, add a tuple around the gradients returned
      for an operation.  This avoids some race conditions.
    aggregation_method: Specifies the method used to combine gradient terms.
      Accepted values are constants defined in the class `AggregationMethod`.
    stop_gradients: Optional. A `Tensor` or list of tensors not to differentiate
      through.

  Returns:
    A list of `sum(dy/dx)` for each x in `xs`.

  Raises:
    LookupError: if one of the operations between `x` and `y` does not
      have a registered gradient function.
    ValueError: if the arguments are invalid.
    RuntimeError: if called in Eager mode.

  """
  if context.in_eager_mode():
    raise RuntimeError("tf.gradients not supported in EAGER mode. Use "
                       "functions in tf.contrib.eager.backprop instead.")
Use " 483 "functions in tf.contrib.eager.backprop instead.") 484 ys = _AsList(ys) 485 xs = _AsList(xs) 486 stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients) 487 if grad_ys is None: 488 grad_ys = [None] * len(ys) 489 else: 490 grad_ys = _AsList(grad_ys) 491 492 with ops.name_scope( 493 name, "gradients", 494 list(ys) + list(xs) + list(stop_gradients) + list(grad_ys)) as grad_scope: 495 ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y") 496 xs = [ 497 x.handle if resource_variable_ops.is_resource_variable(x) else x 498 for x in xs 499 ] 500 xs = ops.internal_convert_n_to_tensor_or_indexed_slices( 501 xs, name="x", as_ref=True) 502 grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops) 503 504 # The approach we take here is as follows: Create a list of all ops in the 505 # subgraph between the ys and xs. Visit these ops in reverse order of ids 506 # to ensure that when we visit an op the gradients w.r.t its outputs have 507 # been collected. Then aggregate these gradients if needed, call the op's 508 # gradient function, and add the generated gradients to the gradients for 509 # its input. 510 511 # Initialize the pending count for ops in the connected subgraph from ys 512 # to the xs. 513 if len(ys) > 1: 514 ys = [array_ops.identity(y) if y.consumers() else y for y in ys] 515 to_ops = [t.op for t in ys] 516 from_ops = [t.op for t in xs] 517 stop_gradient_ops = [t.op for t in stop_gradients] 518 pending_count, loop_state = _PendingCount( 519 ops.get_default_graph(), to_ops, from_ops, colocate_gradients_with_ops) 520 521 # Iterate over the collected ops. 522 # 523 # grads: op => list of gradients received on each output endpoint of the 524 # op. The gradients for each endpoint are initially collected as a list. 525 # When it is time to call the op's gradient function, for each endpoint we 526 # aggregate the list of received gradients into a Add() Operation if there 527 # is more than one. 528 grads = {} 529 530 # Add the initial gradients for the ys. 531 for y, grad_y in zip(ys, grad_ys): 532 _SetGrad(grads, y, grad_y) 533 534 # Initialize queue with to_ops. 535 queue = collections.deque() 536 # Add the ops in 'to_ops' into the queue. 537 to_ops_set = set() 538 for op in to_ops: 539 # 'ready' handles the case where one output gradient relies on 540 # another output's gradient. 541 # pylint: disable=protected-access 542 ready = (pending_count[op._id] == 0) 543 if ready and op._id not in to_ops_set: 544 to_ops_set.add(op._id) 545 queue.append(op) 546 # pylint: enable=protected-access 547 548 if loop_state: 549 loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set) 550 for y in loop_exits: 551 if _IsTrainable(y): 552 _SetGrad(grads, y, loop_state.ZerosLikeForExit(y)) 553 queue.append(y.op) 554 555 stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count) 556 while queue: 557 # generate gradient subgraph for op. 
      op = queue.popleft()
      with _maybe_colocate_with(op, colocate_gradients_with_ops):
        if loop_state:
          loop_state.EnterGradWhileContext(op, before=True)
        out_grads = _AggregatedGrads(grads, op, loop_state, aggregation_method)
        if loop_state:
          loop_state.ExitGradWhileContext(op, before=True)

        grad_fn = None
        # pylint: disable=protected-access
        func_call = None
        is_func_call = ops.get_default_graph()._is_function(op.type)
        has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
        if has_out_grads and (op._id not in stop_ops):
          if is_func_call:
            func_call = ops.get_default_graph()._get_function(op.type)
            grad_fn = func_call.python_grad_func
            # pylint: enable=protected-access
          else:
            # A grad_fn must be defined, either as a function or as None
            # for ops that do not have gradients.
            try:
              grad_fn = ops.get_gradient_function(op)
            except LookupError:
              raise LookupError(
                  "No gradient defined for operation '%s' (op type: %s)" %
                  (op.name, op.type))
        if loop_state:
          loop_state.EnterGradWhileContext(op, before=False)
        if (grad_fn or is_func_call) and has_out_grads:
          # NOTE: If _AggregatedGrads didn't compute a value for the i'th
          # output, it means that the cost does not depend on output[i],
          # therefore dC/doutput[i] is 0.
          for i, out_grad in enumerate(out_grads):
            if (not isinstance(out_grad, ops.Tensor) and not out_grad) and (
                (not grad_fn and is_func_call) or _IsTrainable(op.outputs[i])):
              # Only trainable outputs or outputs for a function call that
              # will use SymbolicGradient get a zero gradient. Gradient
              # functions should ignore the gradient for other outputs.
              # TODO(apassos) gradients of resource handles might be an
              # issue here because of zeros.
              if loop_state:
                out_grads[i] = loop_state.ZerosLike(op, i)
              else:
                out_grads[i] = control_flow_ops.ZerosLikeOutsideLoop(op, i)
          with ops.name_scope(op.name + "_grad"):
            # pylint: disable=protected-access
            with ops.get_default_graph()._original_op(op):
              # pylint: enable=protected-access
              if grad_fn:
                # If grad_fn was found, do not use SymbolicGradient even for
                # functions.
                in_grads = _MaybeCompile(grad_scope, op, func_call,
                                         lambda: grad_fn(op, *out_grads))
              else:
                # For function call ops, we add a 'SymbolicGradient'
                # node to the graph to compute gradients.
                in_grads = _MaybeCompile(grad_scope, op, func_call,
                                         lambda: _SymGrad(op, out_grads))
              in_grads = _AsList(in_grads)
              _VerifyGeneratedGradients(in_grads, op)
              if gate_gradients and len([x for x in in_grads
                                         if x is not None]) > 1:
                with ops.device(None):
                  with ops.colocate_with(None, ignore_existing=True):
                    in_grads = control_flow_ops.tuple(in_grads)
          _LogOpGradients(op, out_grads, in_grads)
        else:
          # If no grad_fn is defined or none of out_grads is available,
          # just propagate a list of None backwards.
          in_grads = [None] * len(op.inputs)
        for i, (t_in, in_grad) in enumerate(zip(op.inputs, in_grads)):
          if in_grad is not None:
            if (isinstance(in_grad, ops.Tensor) and
                t_in.dtype != dtypes.resource):
              try:
                in_grad.set_shape(t_in.get_shape())
              except ValueError:
                raise ValueError(
                    "Incompatible shapes between op input and calculated "
                    "input gradient.  Forward operation: %s.  Input index: %d. "
                    "Original input shape: %s. "
" 640 "Calculated input gradient shape: %s" % 641 (op.name, i, t_in.shape, in_grad.shape)) 642 _SetGrad(grads, t_in, in_grad) 643 if loop_state: 644 loop_state.ExitGradWhileContext(op, before=False) 645 646 # Update pending count for the inputs of op and enqueue ready ops. 647 _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state) 648 649 if loop_state: 650 loop_state.PostProcessing() 651 return [_GetGrad(grads, x) for x in xs] 652 653 654def _HasAnyNotNoneGrads(grads, op): 655 """Return true iff op has real gradient.""" 656 out_grads = _GetGrads(grads, op) 657 for out_grad in out_grads: 658 if isinstance(out_grad, (ops.Tensor, ops.IndexedSlices)): 659 return True 660 if out_grad and isinstance(out_grad, collections.Sequence): 661 if any([g is not None for g in out_grad]): 662 return True 663 return False 664 665 666def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state): 667 """Update pending count for the inputs of op and enqueue ready ops.""" 668 for x in op.inputs: 669 # pylint: disable=protected-access 670 pending_count[x.op._id] -= 1 671 ready = (pending_count[x.op._id] == 0) 672 if loop_state and not ready: 673 ready = ( 674 pending_count[x.op._id] > 0 and control_flow_util.IsLoopSwitch(x.op)) 675 # pylint: enable=protected-access 676 if ready: 677 if control_flow_util.IsLoopExit(x.op): 678 # if x is an exit without real gradient, defer processing them. 679 grad_state = loop_state.GetGradState(x.op, before=False) 680 grad_state.deferred_exits.append(x) 681 grad_state.pending_exits_count -= 1 682 if grad_state.pending_exits_count == 0: 683 # We now have all the exits so process them. 684 has_not_none_grad = False 685 for y in grad_state.deferred_exits: 686 if _HasAnyNotNoneGrads(grads, y.op): 687 has_not_none_grad = True 688 queue.append(y.op) 689 else: 690 grad_state.unused_exits.append(y) 691 if has_not_none_grad: 692 # For an unused exit, if it has trainable outputs, backprop 693 # a zero gradient. Otherwise, just ignore it. 694 for y in grad_state.unused_exits: 695 if _IsTrainable(y): 696 _SetGrad(grads, y, loop_state.ZerosLikeForExit(y)) 697 queue.append(y.op) 698 else: 699 # All exits are "unused" so use None as gradient. 


def _HasAnyNotNoneGrads(grads, op):
  """Return true iff op has real gradient."""
  out_grads = _GetGrads(grads, op)
  for out_grad in out_grads:
    if isinstance(out_grad, (ops.Tensor, ops.IndexedSlices)):
      return True
    if out_grad and isinstance(out_grad, collections.Sequence):
      if any([g is not None for g in out_grad]):
        return True
  return False


def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state):
  """Update pending count for the inputs of op and enqueue ready ops."""
  for x in op.inputs:
    # pylint: disable=protected-access
    pending_count[x.op._id] -= 1
    ready = (pending_count[x.op._id] == 0)
    if loop_state and not ready:
      ready = (
          pending_count[x.op._id] > 0 and control_flow_util.IsLoopSwitch(x.op))
    # pylint: enable=protected-access
    if ready:
      if control_flow_util.IsLoopExit(x.op):
        # if x is an exit without real gradient, defer processing it.
        grad_state = loop_state.GetGradState(x.op, before=False)
        grad_state.deferred_exits.append(x)
        grad_state.pending_exits_count -= 1
        if grad_state.pending_exits_count == 0:
          # We now have all the exits so process them.
          has_not_none_grad = False
          for y in grad_state.deferred_exits:
            if _HasAnyNotNoneGrads(grads, y.op):
              has_not_none_grad = True
              queue.append(y.op)
            else:
              grad_state.unused_exits.append(y)
          if has_not_none_grad:
            # For an unused exit, if it has trainable outputs, backprop
            # a zero gradient. Otherwise, just ignore it.
            for y in grad_state.unused_exits:
              if _IsTrainable(y):
                _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
              queue.append(y.op)
          else:
            # All exits are "unused" so use None as gradient.
            for y in grad_state.unused_exits:
              queue.append(y.op)
      else:
        queue.append(x.op)


def _SetGrad(grads, t, grad):
  """Sets gradient "grad" in "grads" for tensor "t"."""
  op = t.op
  op_grads = grads.get(op)
  if not op_grads:
    op_grads = [[] for _ in xrange(len(op.outputs))]
    grads[op] = op_grads
  t_grads = op_grads[t.value_index]
  if isinstance(t_grads, list):
    t_grads.append(grad)
  else:
    assert control_flow_util.IsLoopSwitch(op)
    op_grads[t.value_index] = grad


def _GetGrad(grads, t):
  """Gets gradient for tensor "t"."""
  op = t.op
  op_grads = grads.get(op)
  if not op_grads:
    return None
  t_grad = op_grads[t.value_index]
  assert not isinstance(
      t_grad, list), ("gradients list should have been aggregated by now.")
  return t_grad


def _GetGrads(grads, op):
  """Gets all gradients for op."""
  if op in grads:
    return grads[op]
  else:
    return [[] for _ in xrange(len(op.outputs))]


def _HandleNestedIndexedSlices(grad):
  assert isinstance(grad, ops.IndexedSlices)
  if isinstance(grad.values, ops.Tensor):
    return grad
  else:
    assert isinstance(grad.values, ops.IndexedSlices)
    g = _HandleNestedIndexedSlices(grad.values)
    return ops.IndexedSlices(g.values, array_ops.gather(
        grad.indices, g.indices), g.dense_shape)


def _AccumulatorShape(inputs):
  shape = tensor_shape.unknown_shape()
  for i in inputs:
    if isinstance(i, ops.Tensor):
      shape = shape.merge_with(i.get_shape())
  return shape


def _LogOpGradients(op, out_grads, in_grads):
  """Log the in and out grads of an op."""
  logging.vlog(1, "Gradient for '" + op.name + "'")

  def _FilterGrad(x):
    if x is None:
      return False
    if isinstance(x, (list, tuple)):
      return bool(x)
    else:
      return True

  logging.vlog(1, "  in  --> %s",
               ", ".join([x.name for x in out_grads if _FilterGrad(x)]))
  logging.vlog(1, "  out --> %s",
               ", ".join([x.name for x in in_grads if _FilterGrad(x)]))


def _MultiDeviceAddN(tensor_list):
  """Adds tensors from potentially multiple devices."""
  # Basic function structure comes from control_flow_ops.group().
  # Sort tensors according to their devices.
  tensors_on_device = collections.defaultdict(lambda: [])
  for tensor in tensor_list:
    tensors_on_device[tensor.device].append(tensor)

  # For each device, add the tensors on that device first.
  # Then gather the partial sums from multiple devices.
  # TODO(sjhwang): Create hierarchical aggregation tree as pbar's suggestion.
  # E.g., aggregate per GPU, then per task, and so on.
  summands = []

  def DeviceKey(dev):
    return "" if dev is None else dev

  for dev in sorted(six.iterkeys(tensors_on_device), key=DeviceKey):
    tensors = tensors_on_device[dev]
    with ops.colocate_with(tensors[0].op, ignore_existing=True):
      summands.append(math_ops.add_n(tensors))

  return math_ops.add_n(summands)


@tf_export("AggregationMethod")
class AggregationMethod(object):
  """A class listing aggregation methods used to combine gradients.

  Computing partial derivatives can require aggregating gradient
  contributions. This class lists the various methods that can
  be used to combine gradients in the graph:

  * `ADD_N`: All of the gradient terms are summed as part of one
    operation using the "AddN" op.  It has the property that all
    gradients must be ready before any aggregation is performed.
  * `DEFAULT`: The system-chosen default aggregation method.
  """
  ADD_N = 0
  DEFAULT = ADD_N
  # The following are experimental and may not be supported in future releases.
  EXPERIMENTAL_TREE = 1
  EXPERIMENTAL_ACCUMULATE_N = 2
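

# Example (illustrative sketch only, not executed as part of this module; the
# values are hypothetical): when a tensor is reused many times, the default
# ADD_N aggregation holds every partial gradient until they can be summed in
# one AddN op.  An experimental method can be requested through `gradients`:
#
#   x = constant_op.constant(1.0)
#   y = x + x + x + x
#   g = gradients(y, x,
#                 aggregation_method=AggregationMethod.EXPERIMENTAL_TREE)
#   # g[0] evaluates to 4.0; the partial gradients are summed pairwise.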


def _AggregatedGrads(grads, op, loop_state, aggregation_method=None):
  """Get the aggregated gradients for op.

  Args:
    grads: The map of memoized gradients.
    op: The op to get gradients for.
    loop_state: An object for maintaining the state of the while loops in the
      graph. It is of type ControlFlowState. None if the graph
      contains no while loops.
    aggregation_method: Specifies the method used to combine gradient terms.
      Accepted values are constants defined in the class `AggregationMethod`.

  Returns:
    A list of gradients, one per each output of `op`. If the gradient
    for a particular output is a list, this function aggregates it
    before returning.

  Raises:
    TypeError: if the incoming grads are not Tensors or IndexedSlices.
    ValueError: if the arguments are invalid.

  """
  if aggregation_method is None:
    aggregation_method = AggregationMethod.DEFAULT
  if aggregation_method not in [
      AggregationMethod.ADD_N, AggregationMethod.EXPERIMENTAL_TREE,
      AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
  ]:
    raise ValueError(
        "Invalid aggregation_method specified %s." % aggregation_method)
  out_grads = _GetGrads(grads, op)
  for i, out_grad in enumerate(out_grads):
    if loop_state:
      if isinstance(out_grad, (ops.Tensor, ops.IndexedSlices)):
        assert control_flow_util.IsLoopSwitch(op)
        continue
    # Grads have to be Tensors or IndexedSlices
    if (isinstance(out_grad, collections.Sequence) and not all([
        isinstance(g, (ops.Tensor, ops.IndexedSlices))
        for g in out_grad
        if g is not None
    ])):
      raise TypeError("gradients have to be either all Tensors "
                      "or all IndexedSlices")
    # Aggregate multiple gradients, and convert [] to None.
    if out_grad:
      if len(out_grad) < 2:
        used = "nop"
        out_grads[i] = out_grad[0]
      elif all([isinstance(g, ops.Tensor) for g in out_grad if g is not None]):
        tensor_shape = _AccumulatorShape(out_grad)
        if (aggregation_method == AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
            and len(out_grad) > 2 and tensor_shape.is_fully_defined()):
          # The benefit of using AccumulateN is that its inputs can be combined
          # in any order and this can allow the expression to be evaluated with
          # a smaller memory footprint.  When used with gpu_allocator_retry,
          # it is possible to compute a sum of terms which are much larger than
          # total GPU memory.
          # AccumulateN can currently only be used if we know the shape for
          # an accumulator variable.  If this is not known, or if we only have
          # 2 grads then we fall through to the "tree" case below.
          used = "accumulate_n"
          out_grads[i] = math_ops.accumulate_n(out_grad)
        elif aggregation_method in [
            AggregationMethod.EXPERIMENTAL_TREE,
            AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
        ]:
          # Aggregate all gradients by doing pairwise sums: this may
          # reduce performance, but it can improve memory because the
          # gradients can be released earlier.
          #
          # TODO(vrv): Consider replacing this with a version of
          # tf.AddN() that eagerly frees its inputs as soon as they are
          # ready, so the order of this tree does not become a problem.
          used = "tree"
          with ops.name_scope(op.name + "_gradient_sum"):
            running_sum = out_grad[0]
            for grad in out_grad[1:]:
              running_sum = math_ops.add_n([running_sum, grad])
            out_grads[i] = running_sum
        else:
          used = "add_n"
          out_grads[i] = _MultiDeviceAddN(out_grad)
        logging.vlog(2, "  _AggregatedGrads %d x %s using %s", len(out_grad),
                     tensor_shape, used)
      else:
        out_grad = math_ops._as_indexed_slices_list(
            [g for g in out_grad if g is not None])
        out_grad = [_HandleNestedIndexedSlices(x) for x in out_grad]
        # Form IndexedSlices out of the concatenated values and
        # indices.
        out_grads[i] = ops.IndexedSlices(
            array_ops.concat([x.values for x in out_grad], 0),
            array_ops.concat([x.indices for x in out_grad], 0),
            out_grad[0].dense_shape)
    else:  # not out_grad
      # out_grads[i] is [], thus its aggregation is simply None.
      out_grads[i] = None
  return out_grads


# TODO(vrv): Make this available when we want to make it public.
def _hessian_vector_product(ys, xs, v):
  """Multiply the Hessian of `ys` wrt `xs` by `v`.

  This is an efficient construction that uses a backprop-like approach
  to compute the product between the Hessian and another vector. The
  Hessian is usually too large to be explicitly computed or even
  represented, but this method allows us to at least multiply by it
  for the same big-O cost as backprop.

  Implicit Hessian-vector products are the main practical, scalable way
  of using second derivatives with neural networks. They allow us to
  do things like construct Krylov subspaces and approximate conjugate
  gradient descent.

  Example: if `y` = `x`^T A `x`, then `hessian_vector_product(y,
  x, v)` will return an expression that evaluates to the same values
  as (A + A.T) `v`.

  Args:
    ys: A scalar value, or a tensor or list of tensors to be summed to
      yield a scalar.
    xs: A list of tensors that we should construct the Hessian over.
    v: A list of tensors, with the same shapes as xs, that we want to
      multiply by the Hessian.

  Returns:
    A list of tensors (or if the list would be length 1, a single tensor)
    containing the product between the Hessian and `v`.

  Raises:
    ValueError: `xs` and `v` have different lengths.

  """

  # Validate the input
  length = len(xs)
  if len(v) != length:
    raise ValueError("xs and v must have the same length.")

  # First backprop
  grads = gradients(ys, xs)

  assert len(grads) == length
  elemwise_products = [
      math_ops.multiply(grad_elem, array_ops.stop_gradient(v_elem))
      for grad_elem, v_elem in zip(grads, v)
      if grad_elem is not None
  ]

  # Second backprop
  return gradients(elemwise_products, xs)
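

# Example (illustrative sketch only, not executed as part of this module; the
# values are hypothetical): for the quadratic form in the docstring above, the
# Hessian-vector product can be checked directly.
#
#   a = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
#   x = constant_op.constant([[1.0], [1.0]])
#   v = constant_op.constant([[1.0], [0.0]])
#   y = math_ops.matmul(math_ops.matmul(array_ops.transpose(x), a), x)
#   hvp = _hessian_vector_product(y, [x], [v])
#   # hvp[0] evaluates to (a + a.T) v = [[2.0], [5.0]].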


@tf_export("hessians")
def hessians(ys,
             xs,
             name="hessians",
             colocate_gradients_with_ops=False,
             gate_gradients=False,
             aggregation_method=None):
  """Constructs the Hessian of sum of `ys` with respect to `x` in `xs`.

  `hessians()` adds ops to the graph to output the Hessian matrix of `ys`
  with respect to `xs`.  It returns a list of `Tensor` of length `len(xs)`
  where each tensor is the Hessian of `sum(ys)`.

  The Hessian is a matrix of second-order partial derivatives of a scalar
  tensor (see https://en.wikipedia.org/wiki/Hessian_matrix for more details).

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    name: Optional name to use for grouping all the gradient ops together.
      Defaults to 'hessians'.
    colocate_gradients_with_ops: See `gradients()` documentation for details.
    gate_gradients: See `gradients()` documentation for details.
    aggregation_method: See `gradients()` documentation for details.

  Returns:
    A list of Hessian matrices of `sum(ys)` for each `x` in `xs`.

  Raises:
    LookupError: if one of the operations between `xs` and `ys` does not
      have a registered gradient function.
  """
  xs = _AsList(xs)
  kwargs = {
      "colocate_gradients_with_ops": colocate_gradients_with_ops,
      "gate_gradients": gate_gradients,
      "aggregation_method": aggregation_method
  }
  # Compute first-order derivatives and iterate for each x in xs.
  hessians = []
  _gradients = gradients(ys, xs, **kwargs)
  for gradient, x in zip(_gradients, xs):
    # Change the shape to one dimension without graph branching.
    gradient = array_ops.reshape(gradient, [-1])

    # Declare an iterator and tensor array loop variables for the gradients.
    n = array_ops.size(x)
    loop_vars = [
        array_ops.constant(0, dtypes.int32),
        tensor_array_ops.TensorArray(x.dtype, n)
    ]
    # Iterate over all elements of the gradient and compute second order
    # derivatives.
    _, hessian = control_flow_ops.while_loop(
        lambda j, _: j < n,
        lambda j, result: (j + 1,
                           result.write(j, gradients(gradient[j], x)[0])),
        loop_vars
    )

    _shape = array_ops.shape(x)
    _reshaped_hessian = array_ops.reshape(hessian.stack(),
                                          array_ops.concat((_shape, _shape), 0))
    hessians.append(_reshaped_hessian)
  return hessians
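

# Example (illustrative sketch only, not executed as part of this module; the
# values are hypothetical): for a simple cubic, `hessians` recovers the full
# matrix of second derivatives.
#
#   x = constant_op.constant([1.0, 2.0])
#   y = math_ops.reduce_sum(x * x * x)
#   h = hessians(y, x)
#   # h[0] evaluates to [[6.0, 0.0], [0.0, 12.0]], i.e. diag(6 * x).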