1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Implements the graph generation for computation of gradients."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import contextlib
23import warnings
24
25import numpy as np
26import six
27from six.moves import xrange  # pylint: disable=redefined-builtin
28
29from tensorflow.core.framework import attr_value_pb2
30from tensorflow.python.eager import context
31from tensorflow.python.framework import constant_op
32from tensorflow.python.framework import dtypes
33from tensorflow.python.framework import ops
34from tensorflow.python.framework import tensor_shape
35from tensorflow.python.framework import tensor_util
36from tensorflow.python.ops import array_grad  # pylint: disable=unused-import
37from tensorflow.python.ops import array_ops
38from tensorflow.python.ops import check_ops  # pylint: disable=unused-import
39from tensorflow.python.ops import control_flow_grad  # pylint: disable=unused-import
40from tensorflow.python.ops import control_flow_ops
41from tensorflow.python.ops import control_flow_util
42from tensorflow.python.ops import functional_ops
43from tensorflow.python.ops import image_grad  # pylint: disable=unused-import
44from tensorflow.python.ops import linalg_grad  # pylint: disable=unused-import
45from tensorflow.python.ops import linalg_ops  # pylint: disable=unused-import
46from tensorflow.python.ops import logging_ops  # pylint: disable=unused-import
47from tensorflow.python.ops import manip_grad  # pylint: disable=unused-import
48from tensorflow.python.ops import math_grad  # pylint: disable=unused-import
49from tensorflow.python.ops import math_ops
50from tensorflow.python.ops import resource_variable_ops
51from tensorflow.python.ops import spectral_grad  # pylint: disable=unused-import
52from tensorflow.python.ops import tensor_array_ops
53from tensorflow.python.platform import tf_logging as logging
54from tensorflow.python.util.tf_export import tf_export
55
56# Warn the user if we convert a sparse representation to dense with at
57# least this number of elements.
58_LARGE_SPARSE_NUM_ELEMENTS = 100000000
59
60
61def _IndexedSlicesToTensor(value, dtype=None, name=None, as_ref=False):
62  """Converts an IndexedSlices object `value` to a Tensor.
63
64  NOTE(mrry): This function is potentially expensive.
65
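  For example, a minimal sketch (the values are arbitrary) of the conversion
  this function performs when `tf.convert_to_tensor` is called on an
  `IndexedSlices`:

  ```python
  slices = tf.IndexedSlices(
      values=tf.constant([[1., 2.], [3., 4.]]),
      indices=tf.constant([0, 2]),
      dense_shape=tf.constant([4, 2]))
  dense = tf.convert_to_tensor(slices)
  # dense evaluates to [[1., 2.], [0., 0.], [3., 4.], [0., 0.]].
  ```
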
66  Args:
67    value: An ops.IndexedSlices object.
68    dtype: The dtype of the Tensor to be returned.
69    name: Optional name to use for the returned Tensor.
70    as_ref: True if a ref is requested.
71
72  Returns:
73    A dense Tensor representing the values in the given IndexedSlices.
74
75  Raises:
76    ValueError: If `dtype` is incompatible with `value.dtype` or `value` has no `dense_shape`.
77  """
78  _ = as_ref
79  if dtype and not dtype.is_compatible_with(value.dtype):
80    raise ValueError(
81        "Tensor conversion requested dtype %s for IndexedSlices with dtype %s" %
82        (dtype.name, value.dtype.name))
83  if value.dense_shape is None:
84    raise ValueError(
85        "Tensor conversion requested for IndexedSlices without dense_shape: %s"
86        % str(value))
87  # TODO(mrry): Consider adding static shape information to
88  # IndexedSlices, to avoid using numpy here.
89  dense_shape_value = tensor_util.constant_value(value.dense_shape)
90  if dense_shape_value is not None:
91    num_elements = np.prod(dense_shape_value)
92    if num_elements >= _LARGE_SPARSE_NUM_ELEMENTS:
93      warnings.warn(
94          "Converting sparse IndexedSlices to a dense Tensor with %d elements. "
95          "This may consume a large amount of memory." % num_elements)
96  else:
97    warnings.warn(
98        "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
99        "This may consume a large amount of memory.")
100  return math_ops.unsorted_segment_sum(
101      value.values, value.indices, value.dense_shape[0], name=name)
102
103
104ops.register_tensor_conversion_function(ops.IndexedSlices,
105                                        _IndexedSlicesToTensor)
106
107
108def _MarkReachedOps(from_ops, reached_ops):
109  """Mark all ops reached from "from_ops".
110
111  Args:
112    from_ops: list of Operations.
113    reached_ops: list of booleans, indexed by operation id.
114  """
115  queue = collections.deque()
116  queue.extend(from_ops)
117  while queue:
118    op = queue.popleft()
119    if not reached_ops[op._id]:
120      reached_ops[op._id] = True
121      for output in op.outputs:
122        queue.extend(output.consumers())
123
124
125def _GatherInputs(to_ops, reached_ops):
126  """List all inputs of to_ops that are in reached_ops.
127
128  Args:
129    to_ops: list of Operations.
130    reached_ops: list of booleans, indexed by operation id.
131
132  Returns:
133    The list of all inputs of to_ops that are in reached_ops.
134    That list includes all elements of to_ops.
135  """
136  inputs = []
137  queue = collections.deque()
138  queue.extend(to_ops)
139  while queue:
140    op = queue.popleft()
141    # We are interested in this op.
142    if reached_ops[op._id]:
143      inputs.append(op)
144      # Clear the boolean so we won't add the inputs again.
145      reached_ops[op._id] = False
146      for inp in op.inputs:
147        queue.append(inp.op)
148  return inputs
149
150
151def _PendingCount(graph, to_ops, from_ops, colocate_gradients_with_ops):
152  """Initialize the pending count for ops between two lists of Operations.
153
154  'pending_count[op._id]' indicates the number of backprop inputs
155  to this operation.
156
157  Args:
158    graph: a Graph.
159    to_ops: list of Operations.
160    from_ops: list of Operations.
161    colocate_gradients_with_ops: Python bool.  See docstring of gradients().
162
163  Returns:
164    A tuple containing: (1) a list of integers indexed by operation id,
165    indicating the number of backprop inputs to this operation, and (2)
166    a ControlFlowState object which is not None if the ops between from_ops
167    and to_ops contain control flow loops.
168  """
169  # Mark reachable ops from from_ops.
170  reached_ops = [False] * (graph._last_id + 1)
171  for op in to_ops:
172    reached_ops[op._id] = True
173  _MarkReachedOps(from_ops, reached_ops)
174
175  # Mark between ops.
176  between_ops = [False] * (graph._last_id + 1)
177  between_op_list = []
178  queue = collections.deque()
179  queue.extend(to_ops)
180  while queue:
181    op = queue.popleft()
182    # We are interested in this op.
183    if reached_ops[op._id]:
184      between_ops[op._id] = True
185      between_op_list.append(op)
186      # Clear the boolean so we won't add the inputs again.
187      reached_ops[op._id] = False
188      for inp in op.inputs:
189        queue.append(inp.op)
190
191  # 'loop_state' is None if there are no while loops.
192  loop_state = control_flow_ops.MaybeCreateControlFlowState(
193      between_op_list, between_ops, colocate_gradients_with_ops)
194
195  # Initialize pending count for between ops.
196  pending_count = [0] * (graph._last_id + 1)
197  for op in between_op_list:
198    for x in op.inputs:
199      if between_ops[x.op._id]:
200        pending_count[x.op._id] += 1
201
202  return pending_count, loop_state
203
204
205def _AsList(x):
206  return x if isinstance(x, (list, tuple)) else [x]
207
208
209def _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops):
210  """Fill in default values for grad_ys.
211
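  For example (an illustrative sketch; `x` and `y` stand for any real-valued
  tensors with `y` depending on `x`), a `None` entry behaves like an all-ones
  seed, so the following two calls build equivalent gradients:

  ```python
  g1 = tf.gradients(y, x)                           # grad_ys defaults to None
  g2 = tf.gradients(y, x, grad_ys=tf.ones_like(y))  # explicit all-ones seed
  ```
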
212  Args:
213    grad_ys: List of gradients, can contain None.
214    ys: List of tensors.
215    colocate_gradients_with_ops: If True, try colocating gradients with
216      the corresponding op.
217
218  Returns:
219    A list of gradients to use, without None.
220
221  Raises:
222    ValueError: If sizes of gradients and inputs don't match
223    TypeError: If type of any gradient is not valid for its input.
224  """
225  if len(grad_ys) != len(ys):
226    raise ValueError("Passed %d grad_ys for %d ys" % (len(grad_ys), len(ys)))
227  grad_ys = ops.convert_n_to_tensor_or_indexed_slices(grad_ys, name="grad_y")
228  new_grad_ys = []
229  for i in xrange(len(grad_ys)):
230    grad_y = grad_ys[i]
231    y = ys[i]
232    with _maybe_colocate_with(y.op, colocate_gradients_with_ops):
233      if grad_y is None:
234        if y.dtype.is_complex:
235          raise TypeError(
236              "Gradients of complex tensors must set grad_ys (y.dtype = %r)" %
237              y.dtype)
238        new_grad_ys.append(
239            array_ops.fill(
240                array_ops.shape(y),
241                constant_op.constant(1, dtype=y.dtype, name="grad_ys_%d" % i)))
242        continue
243      if y.dtype.is_floating or y.dtype.is_integer:
244        if not grad_y.dtype.is_floating and not grad_y.dtype.is_integer:
245          raise TypeError("Gradient type %s generated for real or "
246                          "integer-valued tensor %s with type %s must be "
247                          "real or integer" %
248                          (dtypes.as_dtype(grad_y.dtype).name, y,
249                           dtypes.as_dtype(y.dtype).name))
250      elif y.dtype.is_complex:
251        if not grad_y.dtype.is_complex:
252          raise TypeError("Gradient type %s generated for complex-valued "
253                          "tensor %s with type %s must be complex" %
254                          (dtypes.as_dtype(grad_y.dtype).name, y,
255                           dtypes.as_dtype(y.dtype).name))
256      else:
257        raise TypeError("Tensor %s with type %s must be numeric "
258                        "to obtain a default gradient" %
259                        (y, dtypes.as_dtype(y.dtype).name))
260      # Create a grad_y tensor in the name scope of the gradient.
261      # Required for TensorArrays to identify which gradient call a
262      # grad_y value is coming from.
263      if isinstance(grad_y, ops.IndexedSlices):
264        new_grad_ys.append(
265            ops.IndexedSlices(
266                indices=(array_ops.identity(
267                    grad_y.indices, name="grad_ys_%d_indices" % i)
268                         if isinstance(grad_y.indices, ops.Tensor) else
269                         grad_y.indices),
270                values=(array_ops.identity(
271                    grad_y.values, name="grad_ys_%d_values" % i) if isinstance(
272                        grad_y.values, ops.Tensor) else grad_y.values),
273                dense_shape=(array_ops.identity(
274                    grad_y.dense_shape, name="grad_ys_%d_shape" % i)
275                             if isinstance(grad_y.dense_shape, ops.Tensor) else
276                             grad_y.dense_shape)))
277      else:
278        new_grad_ys.append(array_ops.identity(grad_y, name="grad_ys_%d" % i))
279
280  return new_grad_ys
281
282
283def _IsTrainable(tensor):
284  dtype = dtypes.as_dtype(tensor.dtype)
285  return dtype.base_dtype in (dtypes.float16, dtypes.float32, dtypes.float64,
286                              dtypes.complex64, dtypes.complex128)
287
288
289def _VerifyGeneratedGradients(grads, op):
290  """Verify that gradients are valid in number and type.
291
292  Args:
293    grads: List of generated gradients.
294    op: Operation for which the gradients were generated.
295
296  Raises:
297    ValueError: if sizes of gradients and inputs don't match.
298    TypeError: if type of any gradient is not valid for its input.
299  """
300  if len(grads) != len(op.inputs):
301    raise ValueError("Num gradients %d generated for op %s do not match num "
302                     "inputs %d" % (len(grads), op.node_def, len(op.inputs)))
303
304
305def _StopOps(from_ops, stop_gradient_ops, pending_count):
306  """The set of ops that terminate the gradient computation.
307
308  This computes the frontier of the forward graph *before* which backprop
309  should stop. Operations in the returned set will not be differentiated.
310  This set is defined as the subset of `from_ops` containing ops that have
311  no predecessor in `from_ops`. `pending_count` is the result of
312  `_PendingCount(graph, to_ops, from_ops, colocate_gradients_with_ops)`.
313  An 'op' has predecessors in `from_ops` iff pending_count[op._id] > 0.
314
315  In addition, none of `stop_gradient_ops` will be differentiated.
316
317  Args:
318    from_ops: list of Operations.
319    stop_gradient_ops: list of Operations never to backprop through.
320    pending_count: List of integers, indexed by operation id.
321
322  Returns:
323    The set of operations.
324  """
325  stop_ops = set()
326  for op in from_ops:
327    is_stop_op = True
328    for inp in op.inputs:
329      if pending_count[inp.op._id] > 0:
330        is_stop_op = False
331        break
332    if is_stop_op:
333      stop_ops.add(op._id)
334  stop_ops.update(op._id for op in stop_gradient_ops)  # pylint: disable=protected-access
335  return stop_ops
336
337
338@contextlib.contextmanager
339def _maybe_colocate_with(op, colocate_gradients_with_ops):
340  """Context to colocate with `op` if `colocate_gradients_with_ops`."""
341  if colocate_gradients_with_ops:
342    with ops.colocate_with(op):
343      yield
344  else:
345    yield
346
347
348def _SymGrad(op, out_grads):
349  """Backprop through a function call node op given its outputs' gradients."""
350  f_in = [x for x in op.inputs] + out_grads
351  f_types = [x.dtype for x in op.inputs]
352  f = attr_value_pb2.NameAttrList()
353  f.name = op.type
354  for k in op.node_def.attr:
355    f.attr[k].CopyFrom(op.node_def.attr[k])
356  # pylint: disable=protected-access
357  in_grads = functional_ops._symbolic_gradient(input=f_in, Tout=f_types, f=f)
358  # pylint: enable=protected-access
359  return in_grads
360
361
362def _MaybeCompile(scope, op, func, grad_fn):
363  """Compile the calculation in grad_fn if op was marked as compiled."""
364  scope = scope.rstrip("/").replace("/", "_")
365  if func is not None:
366    xla_compile = func.definition.attr["_XlaCompile"].b
367    xla_separate_compiled_gradients = func.definition.attr[
368        "_XlaSeparateCompiledGradients"].b
369    xla_scope = func.definition.attr["_XlaScope"].s.decode()
370  else:
371    try:
372      xla_compile = op.get_attr("_XlaCompile")
373      xla_separate_compiled_gradients = op.get_attr(
374          "_XlaSeparateCompiledGradients")
375      xla_scope = op.get_attr("_XlaScope").decode()
376    except ValueError:
377      return grad_fn()  # Exit early
378
379  if not xla_compile:
380    return grad_fn()  # Exit early
381
382  # If the gradients are supposed to be compiled separately, we give them a
383  # _XlaScope name that is based on the name_scope of the gradients.  Otherwise
384  # they just inherit the existing _XlaScope name, which lets them be merged
385  # together with the non-gradient computation.
386  if xla_separate_compiled_gradients:
387    xla_grad_scope = "%s_grad_%s" % (xla_scope, scope)
388  else:
389    xla_grad_scope = xla_scope
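  # For instance (illustrative values only): with xla_scope == "jit_scope_0"
  # and scope == "gradients_Foo_grad" (after the rstrip/replace above),
  # separate compilation yields xla_grad_scope ==
  # "jit_scope_0_grad_gradients_Foo_grad", while shared compilation keeps
  # "jit_scope_0".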
390
391  attrs = {
392      "_XlaCompile": attr_value_pb2.AttrValue(b=xla_compile),
393      "_XlaScope": attr_value_pb2.AttrValue(s=xla_grad_scope.encode())
394  }
395  with ops.get_default_graph()._attr_scope(attrs):  # pylint: disable=protected-access
396    return grad_fn()
397
398
399@tf_export("gradients")
400def gradients(ys,
401              xs,
402              grad_ys=None,
403              name="gradients",
404              colocate_gradients_with_ops=False,
405              gate_gradients=False,
406              aggregation_method=None,
407              stop_gradients=None):
408  """Constructs symbolic derivatives of sum of `ys` w.r.t. x in `xs`.
409
410  `ys` and `xs` are each a `Tensor` or a list of tensors.  `grad_ys`
411  is a list of `Tensor`, holding the gradients received by the
412  `ys`. The list must be the same length as `ys`.
413
414  `gradients()` adds ops to the graph to output the derivatives of `ys` with
415  respect to `xs`.  It returns a list of `Tensor` of length `len(xs)` where
416  each tensor is the `sum(dy/dx)` for y in `ys`.
417
418  `grad_ys` is a list of tensors of the same length as `ys` that holds
419  the initial gradients for each y in `ys`.  When `grad_ys` is None,
420  we fill in a tensor of '1's of the shape of y for each y in `ys`.  A
421  user can provide their own initial `grad_ys` to compute the
422  derivatives using a different initial gradient for each y (e.g., if
423  one wanted to weight the gradient differently for each value in
424  each y).
425
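  For example (an illustrative sketch; the weights are arbitrary), an explicit
  `grad_ys` seeds the backward pass with those values instead of ones:

  ```python
  x = tf.constant([1.0, 2.0, 3.0])
  y = 2.0 * x
  # Weight each component of y differently in the sum being differentiated.
  g = tf.gradients(y, x, grad_ys=tf.constant([1.0, 10.0, 100.0]))
  # g[0] evaluates to [2., 20., 200.].
  ```
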
426  `stop_gradients` is a `Tensor` or a list of tensors to be considered constant
427  with respect to all `xs`. These tensors will not be backpropagated through,
428  as though they had been explicitly disconnected using `stop_gradient`.  Among
429  other things, this allows computation of partial derivatives as opposed to
430  total derivatives. For example:
431
432  ```python
433  a = tf.constant(0.)
434  b = 2 * a
435  g = tf.gradients(a + b, [a, b], stop_gradients=[a, b])
436  ```
437
438  Here the partial derivatives `g` evaluate to `[1.0, 1.0]`, compared to the
439  total derivatives `tf.gradients(a + b, [a, b])`, which take into account the
440  influence of `a` on `b` and evaluate to `[3.0, 1.0]`.  Note that the above is
441  equivalent to:
442
443  ```python
444  a = tf.stop_gradient(tf.constant(0.))
445  b = tf.stop_gradient(2 * a)
446  g = tf.gradients(a + b, [a, b])
447  ```
448
449  `stop_gradients` provides a way of stopping gradients after the graph has
450  already been constructed, as compared to `tf.stop_gradient` which is used
451  during graph construction.  When the two approaches are combined,
452  backpropagation stops at both `tf.stop_gradient` nodes and nodes in
453  `stop_gradients`, whichever is encountered first.
454
455  Args:
456    ys: A `Tensor` or list of tensors to be differentiated.
457    xs: A `Tensor` or list of tensors to be used for differentiation.
458    grad_ys: Optional. A `Tensor` or list of tensors the same size as
459      `ys` and holding the gradients computed for each y in `ys`.
460    name: Optional name to use for grouping all the gradient ops together.
461      Defaults to 'gradients'.
462    colocate_gradients_with_ops: If True, try colocating gradients with
463      the corresponding op.
464    gate_gradients: If True, add a tuple around the gradients returned
465      for an operation.  This avoids some race conditions.
466    aggregation_method: Specifies the method used to combine gradient terms.
467      Accepted values are constants defined in the class `AggregationMethod`.
468    stop_gradients: Optional. A `Tensor` or list of tensors not to differentiate
469      through.
470
471  Returns:
472    A list of `sum(dy/dx)` for each x in `xs`.
473
474  Raises:
475    LookupError: if one of the operations between `xs` and `ys` does not
476      have a registered gradient function.
477    ValueError: if the arguments are invalid.
478    RuntimeError: if called in Eager mode.
479
480  """
481  if context.in_eager_mode():
482    raise RuntimeError("tf.gradients not supported in EAGER mode. Use "
483                       "functions in tf.contrib.eager.backprop instead.")
484  ys = _AsList(ys)
485  xs = _AsList(xs)
486  stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)
487  if grad_ys is None:
488    grad_ys = [None] * len(ys)
489  else:
490    grad_ys = _AsList(grad_ys)
491
492  with ops.name_scope(
493      name, "gradients",
494      list(ys) + list(xs) + list(stop_gradients) + list(grad_ys)) as grad_scope:
495    ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
496    xs = [
497        x.handle if resource_variable_ops.is_resource_variable(x) else x
498        for x in xs
499    ]
500    xs = ops.internal_convert_n_to_tensor_or_indexed_slices(
501        xs, name="x", as_ref=True)
502    grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops)
503
504    # The approach we take here is as follows: Create a list of all ops in the
505    # subgraph between the ys and xs.  Visit these ops in reverse order of ids
506    # to ensure that when we visit an op the gradients w.r.t its outputs have
507    # been collected.  Then aggregate these gradients if needed, call the op's
508    # gradient function, and add the generated gradients to the gradients for
509    # its input.
510
511    # Initialize the pending count for ops in the connected subgraph from ys
512    # to the xs.
513    if len(ys) > 1:
514      ys = [array_ops.identity(y) if y.consumers() else y for y in ys]
515    to_ops = [t.op for t in ys]
516    from_ops = [t.op for t in xs]
517    stop_gradient_ops = [t.op for t in stop_gradients]
518    pending_count, loop_state = _PendingCount(
519        ops.get_default_graph(), to_ops, from_ops, colocate_gradients_with_ops)
520
521    # Iterate over the collected ops.
522    #
523    # grads: op => list of gradients received on each output endpoint of the
524    # op.  The gradients for each endpoint are initially collected as a list.
525    # When it is time to call the op's gradient function, for each endpoint we
526    # aggregate the list of received gradients into a Add() Operation if there
527    # is more than one.
528    grads = {}
529
530    # Add the initial gradients for the ys.
531    for y, grad_y in zip(ys, grad_ys):
532      _SetGrad(grads, y, grad_y)
533
534    # Initialize queue with to_ops.
535    queue = collections.deque()
536    # Add the ops in 'to_ops' into the queue.
537    to_ops_set = set()
538    for op in to_ops:
539      # 'ready' handles the case where one output gradient relies on
540      # another output's gradient.
541      # pylint: disable=protected-access
542      ready = (pending_count[op._id] == 0)
543      if ready and op._id not in to_ops_set:
544        to_ops_set.add(op._id)
545        queue.append(op)
546      # pylint: enable=protected-access
547
548    if loop_state:
549      loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set)
550      for y in loop_exits:
551        if _IsTrainable(y):
552          _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
553          queue.append(y.op)
554
555    stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count)
556    while queue:
557      # generate gradient subgraph for op.
558      op = queue.popleft()
559      with _maybe_colocate_with(op, colocate_gradients_with_ops):
560        if loop_state:
561          loop_state.EnterGradWhileContext(op, before=True)
562        out_grads = _AggregatedGrads(grads, op, loop_state, aggregation_method)
563        if loop_state:
564          loop_state.ExitGradWhileContext(op, before=True)
565
566        grad_fn = None
567        # pylint: disable=protected-access
568        func_call = None
569        is_func_call = ops.get_default_graph()._is_function(op.type)
570        has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
571        if has_out_grads and (op._id not in stop_ops):
572          if is_func_call:
573            func_call = ops.get_default_graph()._get_function(op.type)
574            grad_fn = func_call.python_grad_func
575            # pylint: enable=protected-access
576          else:
577            # A grad_fn must be defined, either as a function or as None
578            # for ops that do not have gradients.
579            try:
580              grad_fn = ops.get_gradient_function(op)
581            except LookupError:
582              raise LookupError(
583                  "No gradient defined for operation '%s' (op type: %s)" %
584                  (op.name, op.type))
585        if loop_state:
586          loop_state.EnterGradWhileContext(op, before=False)
587        if (grad_fn or is_func_call) and has_out_grads:
588          # NOTE: If _AggregatedGrads didn't compute a value for the i'th
589          # output, it means that the cost does not depend on output[i],
590          # therefore dC/doutput[i] is 0.
591          for i, out_grad in enumerate(out_grads):
592            if (not isinstance(out_grad, ops.Tensor) and not out_grad) and (
593                (not grad_fn and is_func_call) or _IsTrainable(op.outputs[i])):
594              # Only trainable outputs or outputs for a function call that
595              # will use SymbolicGradient get a zero gradient. Gradient
596              # functions should ignore the gradient for other outputs.
597              # TODO(apassos) gradients of resource handles might be an
598              # issue here because of zeros.
599              if loop_state:
600                out_grads[i] = loop_state.ZerosLike(op, i)
601              else:
602                out_grads[i] = control_flow_ops.ZerosLikeOutsideLoop(op, i)
603          with ops.name_scope(op.name + "_grad"):
604            # pylint: disable=protected-access
605            with ops.get_default_graph()._original_op(op):
606              # pylint: enable=protected-access
607              if grad_fn:
608                # If grad_fn was found, do not use SymbolicGradient even for
609                # functions.
610                in_grads = _MaybeCompile(grad_scope, op, func_call,
611                                         lambda: grad_fn(op, *out_grads))
612              else:
613                # For function call ops, we add a 'SymbolicGradient'
614                # node to the graph to compute gradients.
615                in_grads = _MaybeCompile(grad_scope, op, func_call,
616                                         lambda: _SymGrad(op, out_grads))
617              in_grads = _AsList(in_grads)
618              _VerifyGeneratedGradients(in_grads, op)
619              if gate_gradients and len([x for x in in_grads
620                                         if x is not None]) > 1:
621                with ops.device(None):
622                  with ops.colocate_with(None, ignore_existing=True):
623                    in_grads = control_flow_ops.tuple(in_grads)
624          _LogOpGradients(op, out_grads, in_grads)
625        else:
626          # If no grad_fn is defined or none of out_grads is available,
627          # just propagate a list of None backwards.
628          in_grads = [None] * len(op.inputs)
629        for i, (t_in, in_grad) in enumerate(zip(op.inputs, in_grads)):
630          if in_grad is not None:
631            if (isinstance(in_grad, ops.Tensor) and
632                t_in.dtype != dtypes.resource):
633              try:
634                in_grad.set_shape(t_in.get_shape())
635              except ValueError:
636                raise ValueError(
637                    "Incompatible shapes between op input and calculated "
638                    "input gradient.  Forward operation: %s.  Input index: %d. "
639                    "Original input shape: %s.  "
640                    "Calculated input gradient shape: %s" %
641                    (op.name, i, t_in.shape, in_grad.shape))
642            _SetGrad(grads, t_in, in_grad)
643        if loop_state:
644          loop_state.ExitGradWhileContext(op, before=False)
645
646      # Update pending count for the inputs of op and enqueue ready ops.
647      _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state)
648
649  if loop_state:
650    loop_state.PostProcessing()
651  return [_GetGrad(grads, x) for x in xs]
652
653
654def _HasAnyNotNoneGrads(grads, op):
655  """Return True iff op has a real gradient."""
656  out_grads = _GetGrads(grads, op)
657  for out_grad in out_grads:
658    if isinstance(out_grad, (ops.Tensor, ops.IndexedSlices)):
659      return True
660    if out_grad and isinstance(out_grad, collections.Sequence):
661      if any([g is not None for g in out_grad]):
662        return True
663  return False
664
665
666def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state):
667  """Update pending count for the inputs of op and enqueue ready ops."""
668  for x in op.inputs:
669    # pylint: disable=protected-access
670    pending_count[x.op._id] -= 1
671    ready = (pending_count[x.op._id] == 0)
672    if loop_state and not ready:
673      ready = (
674          pending_count[x.op._id] > 0 and control_flow_util.IsLoopSwitch(x.op))
675    # pylint: enable=protected-access
676    if ready:
677      if control_flow_util.IsLoopExit(x.op):
678        # If x is an exit without a real gradient, defer processing it.
679        grad_state = loop_state.GetGradState(x.op, before=False)
680        grad_state.deferred_exits.append(x)
681        grad_state.pending_exits_count -= 1
682        if grad_state.pending_exits_count == 0:
683          # We now have all the exits so process them.
684          has_not_none_grad = False
685          for y in grad_state.deferred_exits:
686            if _HasAnyNotNoneGrads(grads, y.op):
687              has_not_none_grad = True
688              queue.append(y.op)
689            else:
690              grad_state.unused_exits.append(y)
691          if has_not_none_grad:
692            # For an unused exit, if it has trainable outputs, backprop
693            # a zero gradient. Otherwise, just ignore it.
694            for y in grad_state.unused_exits:
695              if _IsTrainable(y):
696                _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
697              queue.append(y.op)
698          else:
699            # All exits are "unused" so use None as gradient.
700            for y in grad_state.unused_exits:
701              queue.append(y.op)
702      else:
703        queue.append(x.op)
704
705
706def _SetGrad(grads, t, grad):
707  """Sets gradient "grad" in "grads" for tensor "t"."""
708  op = t.op
709  op_grads = grads.get(op)
710  if not op_grads:
711    op_grads = [[] for _ in xrange(len(op.outputs))]
712    grads[op] = op_grads
713  t_grads = op_grads[t.value_index]
714  if isinstance(t_grads, list):
715    t_grads.append(grad)
716  else:
717    assert control_flow_util.IsLoopSwitch(op)
718    op_grads[t.value_index] = grad
719
720
721def _GetGrad(grads, t):
722  """Gets gradient for tensor "t"."""
723  op = t.op
724  op_grads = grads.get(op)
725  if not op_grads:
726    return None
727  t_grad = op_grads[t.value_index]
728  assert not isinstance(
729      t_grad, list), ("gradients list should have been aggregated by now.")
730  return t_grad
731
732
733def _GetGrads(grads, op):
734  """Gets all gradients for op."""
735  if op in grads:
736    return grads[op]
737  else:
738    return [[] for _ in xrange(len(op.outputs))]
739
740
741def _HandleNestedIndexedSlices(grad):
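  """Flattens `grad` so that `grad.values` is a plain `Tensor`.

  `IndexedSlices` gradients can nest (their `values` may themselves be
  `IndexedSlices`); this gathers the outer indices through each inner level.
  """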
742  assert isinstance(grad, ops.IndexedSlices)
743  if isinstance(grad.values, ops.Tensor):
744    return grad
745  else:
746    assert isinstance(grad.values, ops.IndexedSlices)
747    g = _HandleNestedIndexedSlices(grad.values)
748    return ops.IndexedSlices(g.values, array_ops.gather(
749        grad.indices, g.indices), g.dense_shape)
750
751
752def _AccumulatorShape(inputs):
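  """Returns the best static shape obtainable by merging all Tensor inputs."""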
753  shape = tensor_shape.unknown_shape()
754  for i in inputs:
755    if isinstance(i, ops.Tensor):
756      shape = shape.merge_with(i.get_shape())
757  return shape
758
759
760def _LogOpGradients(op, out_grads, in_grads):
761  """Log the in and out grads of an op."""
762  logging.vlog(1, "Gradient for '" + op.name + "'")
763
764  def _FilterGrad(x):
765    if x is None:
766      return False
767    if isinstance(x, (list, tuple)):
768      return bool(x)
769    else:
770      return True
771
772  logging.vlog(1, "  in  --> %s",
773               ", ".join([x.name for x in out_grads if _FilterGrad(x)]))
774  logging.vlog(1, "  out --> %s",
775               ", ".join([x.name for x in in_grads if _FilterGrad(x)]))
776
777
778def _MultiDeviceAddN(tensor_list):
779  """Adds tensors from potentially multiple devices."""
780  # Basic function structure comes from control_flow_ops.group().
781  # Sort tensors according to their devices.
782  tensors_on_device = collections.defaultdict(lambda: [])
783  for tensor in tensor_list:
784    tensors_on_device[tensor.device].append(tensor)
785
786  # For each device, add the tensors on that device first.
787  # Then gather the partial sums from multiple devices.
788  # TODO(sjhwang): Create hierarchical aggregation tree as pbar's suggestion.
789  # E.g., aggregate per GPU, then per task, and so on.
790  summands = []
791
792  def DeviceKey(dev):
793    return "" if dev is None else dev
794
795  for dev in sorted(six.iterkeys(tensors_on_device), key=DeviceKey):
796    tensors = tensors_on_device[dev]
797    with ops.colocate_with(tensors[0].op, ignore_existing=True):
798      summands.append(math_ops.add_n(tensors))
799
800  return math_ops.add_n(summands)
801
802
803@tf_export("AggregationMethod")
804class AggregationMethod(object):
805  """A class listing aggregation methods used to combine gradients.
806
807  Computing partial derivatives can require aggregating gradient
808  contributions. This class lists the various methods that can
809  be used to combine gradients in the graph:
810
811  *  `ADD_N`: All of the gradient terms are summed as part of one
812     operation using the "AddN" op. It has the property that all
813     gradients must be ready before any aggregation is performed.
814  *  `DEFAULT`: The system-chosen default aggregation method.
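
  For example (illustrative only; `ys` and `xs` stand for any tensors), the
  method is selected when calling `tf.gradients`:

  ```python
  grads = tf.gradients(
      ys, xs, aggregation_method=tf.AggregationMethod.EXPERIMENTAL_TREE)
  ```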
815  """
816  ADD_N = 0
817  DEFAULT = ADD_N
818  # The following are experimental and may not be supported in future releases.
819  EXPERIMENTAL_TREE = 1
820  EXPERIMENTAL_ACCUMULATE_N = 2
821
822
823def _AggregatedGrads(grads, op, loop_state, aggregation_method=None):
824  """Get the aggregated gradients for op.
825
826  Args:
827    grads: The map of memoized gradients.
828    op: The op to get gradients for.
829    loop_state: An object for maintaining the state of the while loops in the
830                graph. It is of type ControlFlowState. None if the graph
831                contains no while loops.
832    aggregation_method: Specifies the method used to combine gradient terms.
833      Accepted values are constants defined in the class `AggregationMethod`.
834
835  Returns:
836    A list of gradients, one per output of `op`. If the gradient
837      for a particular output is a list, this function aggregates it
838      before returning.
839
840  Raises:
841    TypeError: if the incoming grads are not Tensors or IndexedSlices.
842    ValueError: if the arguments are invalid.
843
844  """
845  if aggregation_method is None:
846    aggregation_method = AggregationMethod.DEFAULT
847  if aggregation_method not in [
848      AggregationMethod.ADD_N, AggregationMethod.EXPERIMENTAL_TREE,
849      AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
850  ]:
851    raise ValueError(
852        "Invalid aggregation_method specified %s." % aggregation_method)
853  out_grads = _GetGrads(grads, op)
854  for i, out_grad in enumerate(out_grads):
855    if loop_state:
856      if isinstance(out_grad, (ops.Tensor, ops.IndexedSlices)):
857        assert control_flow_util.IsLoopSwitch(op)
858        continue
859    # Grads have to be Tensors or IndexedSlices
860    if (isinstance(out_grad, collections.Sequence) and not all([
861        isinstance(g, (ops.Tensor, ops.IndexedSlices))
862        for g in out_grad
863        if g is not None
864    ])):
865      raise TypeError("gradients have to be either all Tensors "
866                      "or all IndexedSlices")
867    # Aggregate multiple gradients, and convert [] to None.
868    if out_grad:
869      if len(out_grad) < 2:
870        used = "nop"
871        out_grads[i] = out_grad[0]
872      elif all([isinstance(g, ops.Tensor) for g in out_grad if g is not None]):
873        tensor_shape = _AccumulatorShape(out_grad)
874        if (aggregation_method == AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
875            and len(out_grad) > 2 and tensor_shape.is_fully_defined()):
876          # The benefit of using AccumulateN is that its inputs can be combined
877          # in any order and this can allow the expression to be evaluated with
878          # a smaller memory footprint.  When used with gpu_allocator_retry,
879          # it is possible to compute a sum of terms which are much larger than
880          # total GPU memory.
881          # AccumulateN can currently only be used if we know the shape for
882          # an accumulator variable.  If this is not known, or if we only have
883          # 2 grads then we fall through to the "tree" case below.
884          used = "accumulate_n"
885          out_grads[i] = math_ops.accumulate_n(out_grad)
886        elif aggregation_method in [
887            AggregationMethod.EXPERIMENTAL_TREE,
888            AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
889        ]:
890          # Aggregate all gradients by doing pairwise sums: this may
891          # reduce performance, but it can improve memory because the
892          # gradients can be released earlier.
893          #
894          # TODO(vrv): Consider replacing this with a version of
895          # tf.AddN() that eagerly frees its inputs as soon as they are
896          # ready, so the order of this tree does not become a problem.
897          used = "tree"
898          with ops.name_scope(op.name + "_gradient_sum"):
899            running_sum = out_grad[0]
900            for grad in out_grad[1:]:
901              running_sum = math_ops.add_n([running_sum, grad])
902            out_grads[i] = running_sum
903        else:
904          used = "add_n"
905          out_grads[i] = _MultiDeviceAddN(out_grad)
906        logging.vlog(2, "  _AggregatedGrads %d x %s using %s", len(out_grad),
907                     tensor_shape, used)
908      else:
909        out_grad = math_ops._as_indexed_slices_list(
910            [g for g in out_grad if g is not None])
911        out_grad = [_HandleNestedIndexedSlices(x) for x in out_grad]
912        # Form IndexedSlices out of the concatenated values and
913        # indices.
914        out_grads[i] = ops.IndexedSlices(
915            array_ops.concat([x.values for x in out_grad], 0),
916            array_ops.concat([x.indices for x in out_grad], 0),
917            out_grad[0].dense_shape)
918    else:  # not out_grad
919      # out_grads[i] is [], thus its aggregation is simply None.
920      out_grads[i] = None
921  return out_grads
922
923
924# TODO(vrv): Make this available when we want to make it public.
925def _hessian_vector_product(ys, xs, v):
926  """Multiply the Hessian of `ys` wrt `xs` by `v`.
927
928  This is an efficient construction that uses a backprop-like approach
929  to compute the product between the Hessian and another vector. The
930  Hessian is usually too large to be explicitly computed or even
931  represented, but this method allows us to at least multiply by it
932  for the same big-O cost as backprop.
933
934  Implicit Hessian-vector products are the main practical, scalable way
935  of using second derivatives with neural networks. They allow us to
936  do things like construct Krylov subspaces and approximate conjugate
937  gradient descent.
938
939  Example: if `y = 1/2 * x^T A x`, then `hessian_vector_product(y, x, v)`
940  will return an expression that evaluates to the same values as
941  `0.5 * (A + A.T) v` (i.e. `A v` when `A` is symmetric).
942
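  In code, the construction is essentially the following double backprop
  (a sketch, assuming `ys`, `xs` and `v` are already defined):

  ```python
  grads = tf.gradients(ys, xs)                        # first backprop
  prods = [g * tf.stop_gradient(u) for g, u in zip(grads, v)]
  hvp = tf.gradients(prods, xs)                       # second backprop
  ```
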
943  Args:
944    ys: A scalar value, or a tensor or list of tensors to be summed to
945        yield a scalar.
946    xs: A list of tensors that we should construct the Hessian over.
947    v: A list of tensors, with the same shapes as xs, that we want to
948       multiply by the Hessian.
949
950  Returns:
951    A list of tensors (or if the list would be length 1, a single tensor)
952    containing the product between the Hessian and `v`.
953
954  Raises:
955    ValueError: `xs` and `v` have different length.
956
957  """
958
959  # Validate the input
960  length = len(xs)
961  if len(v) != length:
962    raise ValueError("xs and v must have the same length.")
963
964  # First backprop
965  grads = gradients(ys, xs)
966
967  assert len(grads) == length
968  elemwise_products = [
969      math_ops.multiply(grad_elem, array_ops.stop_gradient(v_elem))
970      for grad_elem, v_elem in zip(grads, v)
971      if grad_elem is not None
972  ]
973
974  # Second backprop
975  return gradients(elemwise_products, xs)
976
977
978@tf_export("hessians")
979def hessians(ys,
980             xs,
981             name="hessians",
982             colocate_gradients_with_ops=False,
983             gate_gradients=False,
984             aggregation_method=None):
985  """Constructs the Hessian of sum of `ys` with respect to `x` in `xs`.
986
987  `hessians()` adds ops to the graph to output the Hessian matrix of `ys`
988  with respect to `xs`.  It returns a list of `Tensor` of length `len(xs)`
989  where each tensor is the Hessian of `sum(ys)`.
990
991  The Hessian is a matrix of second-order partial derivatives of a scalar
992  tensor (see https://en.wikipedia.org/wiki/Hessian_matrix for more details).
993
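  For example (an illustrative sketch), for a quadratic the Hessian recovers
  the symmetrized coefficient matrix:

  ```python
  x = tf.constant([1.0, 2.0])
  y = 3.0 * x[0] ** 2 + 2.0 * x[0] * x[1] + x[1] ** 2
  h = tf.hessians(y, x)[0]  # evaluates to [[6., 2.], [2., 2.]]
  ```
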
994  Args:
995    ys: A `Tensor` or list of tensors to be differentiated.
996    xs: A `Tensor` or list of tensors to be used for differentiation.
997    name: Optional name to use for grouping all the gradient ops together.
998      Defaults to 'hessians'.
999    colocate_gradients_with_ops: See `gradients()` documentation for details.
1000    gate_gradients: See `gradients()` documentation for details.
1001    aggregation_method: See `gradients()` documentation for details.
1002
1003  Returns:
1004    A list of Hessian matrices of `sum(ys)` for each `x` in `xs`.
1005
1006  Raises:
1007    LookupError: if one of the operations between `xs` and `ys` does not
1008      have a registered gradient function.
1009  """
1010  xs = _AsList(xs)
1011  kwargs = {
1012      "colocate_gradients_with_ops": colocate_gradients_with_ops,
1013      "gate_gradients": gate_gradients,
1014      "aggregation_method": aggregation_method
1015  }
1016  # Compute first-order derivatives and iterate for each x in xs.
1017  hessians = []
1018  _gradients = gradients(ys, xs, **kwargs)
1019  for gradient, x in zip(_gradients, xs):
1020    # Flatten the gradient to one dimension without graph branching.
1021    gradient = array_ops.reshape(gradient, [-1])
1022
1023    # Declare an iterator and tensor array loop variables for the gradients.
1024    n = array_ops.size(x)
1025    loop_vars = [
1026        array_ops.constant(0, dtypes.int32),
1027        tensor_array_ops.TensorArray(x.dtype, n)
1028    ]
1029    # Iterate over all elements of the gradient and compute second order
1030    # derivatives.
1031    _, hessian = control_flow_ops.while_loop(
1032        lambda j, _: j < n,
1033        lambda j, result: (j + 1,
1034                           result.write(j, gradients(gradient[j], x)[0])),
1035        loop_vars
1036    )
1037
1038    _shape = array_ops.shape(x)
1039    _reshaped_hessian = array_ops.reshape(hessian.stack(),
1040                                          array_ops.concat((_shape, _shape), 0))
1041    hessians.append(_reshaped_hessian)
1042  return hessians
1043