# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements the graph generation for computation of gradients."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import contextlib
import warnings

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function as framework_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework.func_graph import FuncGraph
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export


# Warn the user if we convert a sparse representation to dense with at
# least this number of elements.
_LARGE_SPARSE_NUM_ELEMENTS = 100000000


def _IndexedSlicesToTensor(value, dtype=None, name=None, as_ref=False):
  """Converts an IndexedSlices object `value` to a Tensor.

  NOTE(mrry): This function is potentially expensive.

  Args:
    value: An ops.IndexedSlices object.
    dtype: The dtype of the Tensor to be returned.
    name: Optional name to use for the returned Tensor.
    as_ref: True if a ref is requested.

  Returns:
    A dense Tensor representing the values in the given IndexedSlices.

  Raises:
    ValueError: If `dtype` is incompatible with the dtype of `value`, or if
      `value` does not have a known dense_shape.
70  """
71  _ = as_ref
72  if dtype and not dtype.is_compatible_with(value.dtype):
73    raise ValueError(
74        "Tensor conversion requested dtype %s for IndexedSlices with dtype %s" %
75        (dtype.name, value.dtype.name))
76  if value.dense_shape is None:
77    raise ValueError(
78        "Tensor conversion requested for IndexedSlices without dense_shape: %s"
79        % str(value))
80  # TODO(mrry): Consider adding static shape information to
81  # IndexedSlices, to avoid using numpy here.
82  if not context.executing_eagerly():
83    dense_shape_value = tensor_util.constant_value(value.dense_shape)
84    if dense_shape_value is not None:
85      num_elements = np.prod(dense_shape_value)
86      if num_elements >= _LARGE_SPARSE_NUM_ELEMENTS:
87        warnings.warn(
88            "Converting sparse IndexedSlices to a dense Tensor with %d "
89            "elements. This may consume a large amount of memory." %
90            num_elements)
91    else:
92      warnings.warn(
93          "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
94          "This may consume a large amount of memory.")
95  return math_ops.unsorted_segment_sum(
96      value.values, value.indices, value.dense_shape[0], name=name)
97
98
99ops.register_tensor_conversion_function(ops.IndexedSlices,
100                                        _IndexedSlicesToTensor)
101
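# Illustrative sketch (hypothetical user code, graph mode assumed): this
# conversion is what fires when an IndexedSlices gradient is fed to an op that
# only accepts dense Tensors, e.g.
#
#   params = tf.Variable(tf.zeros([2000000, 64]))
#   loss = tf.reduce_sum(tf.gather(params, [0, 1, 2]))
#   grad = tf.gradients(loss, params)[0]    # an ops.IndexedSlices
#   dense = tf.convert_to_tensor(grad)      # routed through the function above
#
# and, because the dense shape here has more than _LARGE_SPARSE_NUM_ELEMENTS
# elements, it would also trigger the "Converting sparse IndexedSlices to a
# dense Tensor" warning.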

def _MarkReachedOps(from_ops, reached_ops, func_graphs):
  """Mark all ops reached from "from_ops".

  Args:
    from_ops: list of Operations.
    reached_ops: set of Operations.
    func_graphs: list of FuncGraphs. This method will traverse through
      these functions if they capture from_ops or any reachable ops.
  """
  queue = collections.deque()
  queue.extend(from_ops)
  while queue:
    op = queue.popleft()
    if op not in reached_ops:
      reached_ops.add(op)
      for output in op.outputs:
        if _IsBackpropagatable(output):
          queue.extend(_Consumers(output, func_graphs))


def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
                  xs):
  """Initialize the pending count for ops between two lists of Operations.

  'pending_count[op]' indicates the number of backprop inputs
  to this operation.

  Args:
    to_ops: list of Operations.
    from_ops: list of Operations.
    colocate_gradients_with_ops: Python bool.  See docstring of gradients().
    func_graphs: list of FuncGraphs. This method will traverse through
      these functions if they capture from_ops or any reachable ops. This is
      useful if to_ops occur in a function and from_ops are in an outer function
      or graph.
    xs: list of Tensors.

  Returns:
    A tuple containing: (1) the subset of to_ops reachable from from_ops by a
    path of zero or more backpropagatable tensors, (2) a mapping from operation
    to the number of backprop inputs to that op, and (3) a ControlFlowState
    object which is not None if the ops between from_ops and to_ops contain
    control flow loops.
  """
  # Mark reachable ops from from_ops.
  reached_ops = set()
  _MarkReachedOps(from_ops, reached_ops, func_graphs)
  # X in reached_ops iff X is reachable from from_ops by a path of zero or more
  # backpropagatable tensors.

  reachable_to_ops = set(op for op in to_ops if op in reached_ops)

  # Mark between ops.
  between_ops = set()
  between_op_list = []
  queue = collections.deque()
  queue.extend(to_ops)
  while queue:
    op = queue.popleft()
    # We are interested in this op.
    if op in reached_ops:
      between_ops.add(op)
      between_op_list.append(op)
      # Clear the boolean so we won't add the inputs again.
      reached_ops.remove(op)
      for inp in _NonEagerInputs(op, xs):
        queue.append(inp.op)
  # X in between_ops iff X is on a path of zero or more backpropagatable tensors
  # between from_ops and to_ops

  # 'loop_state' is None if there are no while loops.
  loop_state = control_flow_ops.MaybeCreateControlFlowState(
      between_op_list, between_ops, colocate_gradients_with_ops)

  # Initialize pending count for between ops.
  pending_count = collections.defaultdict(int)
  for op in between_op_list:
    for x in _NonEagerInputs(op, xs):
      if x.op in between_ops:
        pending_count[x.op] += 1

  return reachable_to_ops, pending_count, loop_state

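# Illustrative sketch of the bookkeeping above (hypothetical graph, not part of
# this module): for y = (x + 1) * (x + 2), the mul op consumes both adds, so
# pending_count[add] == pending_count[add_1] == 1 and pending_count[x.op] == 2.
# During backprop each add becomes ready only after mul has delivered its
# output gradient, and x's op only after both adds have.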

def _AsList(x):
  return x if isinstance(x, (list, tuple)) else [x]


def _DefaultGradYs(grad_ys,
                   ys,
                   colocate_gradients_with_ops,
                   gradient_uid="__unsupported__"):
  """Fill in default values for grad_ys.

  Args:
    grad_ys: List of gradients, can contain None.
    ys: List of tensors.
    colocate_gradients_with_ops: If True, try colocating gradients with
      the corresponding op.
    gradient_uid: A unique identifier within the graph indicating
      which invocation of gradients is being executed. Used to cluster
      ops for compilation.

  Returns:
    A list of gradients to use, without None.

  Raises:
    ValueError: If sizes of gradients and inputs don't match
    TypeError: If type of any gradient is not valid for its input.
  """
  if len(grad_ys) != len(ys):
    raise ValueError("Passed %d grad_ys for %d ys" % (len(grad_ys), len(ys)))
  grad_ys = ops.convert_n_to_tensor_or_indexed_slices(grad_ys, name="grad_y")
  new_grad_ys = []
  for i in xrange(len(grad_ys)):
    grad_y = grad_ys[i]
    y = ys[i]
    with _maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops):
      if grad_y is None:
        if y.dtype.is_complex:
          raise TypeError(
              "Gradients of complex tensors must set grad_ys (y.dtype = %r)" %
              y.dtype)
        new_grad_ys.append(
            array_ops.fill(
                array_ops.shape(y),
                constant_op.constant(1, dtype=y.dtype, name="grad_ys_%d" % i)))
        continue
      if y.dtype.is_floating or y.dtype.is_integer:
        if not grad_y.dtype.is_floating and not grad_y.dtype.is_integer:
          raise TypeError(
              "Gradient type %s generated for real or "
              "integer-valued tensor %s with type %s must be "
              "real or integer" % (dtypes.as_dtype(grad_y.dtype).name, y,
                                   dtypes.as_dtype(y.dtype).name))
      elif y.dtype.is_complex:
        if not grad_y.dtype.is_complex:
          raise TypeError(
              "Gradient type %s generated for complex-valued "
242              "tensor %s with type %s must be real" % (dtypes.as_dtype(
                  grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name))
      elif y.dtype == dtypes.variant:
        if grad_y.dtype != dtypes.variant:
          raise TypeError(
              "Gradient type %s generated for variant "
              "tensor %s with type %s must be variant" % (dtypes.as_dtype(
                  grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name))
      elif y.dtype == dtypes.resource:
        # We assume y is the handle of a ResourceVariable. The gradient of a
        # ResourceVariable should be a numeric value, not another resource.
        if grad_y.dtype == dtypes.resource:
          raise TypeError("Input gradient %s for resource tensor %s should not "
                          "be a resource" % (grad_y, y))
      else:
        raise TypeError(
            "Tensor %s with type %s must be numeric "
            "to obtain a default gradient" % (y, dtypes.as_dtype(y.dtype).name))
      # Create a grad_y tensor in the name scope of the gradient.
      # Required for TensorArrays to identify which gradient call a
      # grad_y value is coming from.
      if isinstance(grad_y, ops.IndexedSlices):
        new_grad_ys.append(
            ops.IndexedSlices(
                indices=(array_ops.identity(
                    grad_y.indices, name="grad_ys_%d_indices" % i)
                         if isinstance(grad_y.indices, ops.Tensor) else
                         grad_y.indices),
                values=(array_ops.identity(
                    grad_y.values, name="grad_ys_%d_values" % i) if isinstance(
                        grad_y.values, ops.Tensor) else grad_y.values),
                dense_shape=(array_ops.identity(
                    grad_y.dense_shape, name="grad_ys_%d_shape" % i)
                             if isinstance(grad_y.dense_shape, ops.Tensor) else
                             grad_y.dense_shape)))
      else:
        new_grad_ys.append(array_ops.identity(grad_y, name="grad_ys_%d" % i))

  return new_grad_ys

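# Illustrative sketch (hypothetical user code, graph mode assumed): leaving
# grad_ys as None seeds backprop with ones of y's shape, so
#
#   x = tf.constant([2.0, 3.0])
#   y = x * x
#   g1 = tf.gradients(y, x)                            # default seed
#   g2 = tf.gradients(y, x, grad_ys=tf.ones_like(y))   # explicit seed
#
# should produce the same values; a different grad_ys simply rescales each
# component of the backpropagated gradient.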

def IsTrainable(tensor_or_dtype):
  if isinstance(tensor_or_dtype, ops.Tensor):
    dtype = tensor_or_dtype.dtype
  else:
    dtype = tensor_or_dtype
  dtype = dtypes.as_dtype(dtype)
  return dtype.base_dtype in (dtypes.float16, dtypes.float32, dtypes.float64,
                              dtypes.complex64, dtypes.complex128,
                              dtypes.resource, dtypes.variant)


def _IsBackpropagatable(tensor):
  if IsTrainable(tensor):
    return True
  dtype = dtypes.as_dtype(tensor.dtype)
  return dtype.base_dtype == dtypes.bfloat16

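# Illustrative sketch: trainability here is purely a dtype property, e.g.
#
#   IsTrainable(dtypes.float32)   # True
#   IsTrainable(dtypes.int32)     # False: integer tensors get no gradient
#   IsTrainable(dtypes.bfloat16)  # False, but _IsBackpropagatable accepts it
#
# so an int32 output is not traversed when marking reachable ops, even if it
# lies on a path between the xs and the ys.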

def _VerifyGeneratedGradients(grads, op):
  """Verify that gradients are valid in number and type.

  Args:
    grads: List of generated gradients.
    op: Operation for which the gradients were generated.

  Raises:
    ValueError: if sizes of gradients and inputs don't match.
    TypeError: if type of any gradient is not valid for its input.
  """
  # While ops have inputs added to them during the gradient computation, so we
  # skip the check below. See while_v2 for details.
  if op.type == "While": return

  if len(grads) != len(op.inputs):
    raise ValueError("Num gradients %d generated for op %s do not match num "
                     "inputs %d" % (len(grads), op.node_def, len(op.inputs)))


def _StopOps(from_ops, stop_gradient_ops, pending_count, xs):
  """The set of ops that terminate the gradient computation.

  This computes the frontier of the forward graph *before* which backprop
  should stop. Operations in the returned set will not be differentiated.
  This set is defined as the subset of `from_ops` containing ops that have
  no predecessor in `from_ops`. `pending_count` is the mapping computed by
  `_PendingCount()` for the ops between the ys and `from_ops`; an op has
  predecessors in `from_ops` iff pending_count[op] > 0.

  In addition, none of `stop_gradient_ops` will be differentiated.

  Args:
    from_ops: list of Operations.
    stop_gradient_ops: list of Operations never to backprop through.
    pending_count: mapping from operation to number of backprop inputs.
    xs: list of Tensors.

  Returns:
    The set of operations.
  """
  stop_ops = set()
  for op in from_ops:
    is_stop_op = True
    for inp in _NonEagerInputs(op, xs):
      if pending_count[inp.op] > 0:
        is_stop_op = False
        break
    if is_stop_op:
      stop_ops.add(op)
  stop_ops.update(op for op in stop_gradient_ops)
  return stop_ops

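# Illustrative sketch (hypothetical user code): stop_gradient_ops ultimately
# comes from the `stop_gradients` argument of tf.gradients, e.g.
#
#   a = tf.constant(1.0)
#   b = 2.0 * a
#   c = a + b
#   tf.gradients(c, [a, b], stop_gradients=[b])
#
# puts b's op into the returned stop set, so backprop never descends from b
# into the 2.0 * a subgraph: the gradient of c w.r.t. a is 1.0 instead of 3.0.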

@contextlib.contextmanager
def _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops):  # pylint: disable=invalid-name
  """Context to colocate with `op` if `colocate_gradients_with_ops`."""
  if colocate_gradients_with_ops:
    with ops._colocate_with_for_gradient(op, gradient_uid):  # pylint: disable=protected-access
      yield
  else:
    yield


def _IsPartitionedCall(op):
  return op.type == "PartitionedCall" or op.type == "StatefulPartitionedCall"


def _SymGrad(op, out_grads):
  """Backprop through a function call node op given its outputs' gradients."""
  f_in = [x for x in op.inputs] + out_grads
  f_types = [x.dtype for x in op.inputs]
  f = attr_value_pb2.NameAttrList()
  if _IsPartitionedCall(op):
    f.name = op.get_attr("f").name
  else:
    f.name = op.type
  for k in op.node_def.attr:
    f.attr[k].CopyFrom(op.node_def.attr[k])
  # TODO(apassos) use a better dtype here
  in_grads = functional_ops.symbolic_gradient(
      input=f_in,
      Tout=[x if x != dtypes.resource else dtypes.float32 for x in f_types],
      f=f)
  return in_grads


def _MaybeCompile(scope, op, func, grad_fn):
  """Compile the calculation in grad_fn if op was marked as compiled."""
  scope = scope.rstrip("/").replace("/", "_")
  if func is not None:
    xla_compile = func.definition.attr["_XlaCompile"].b
    xla_separate_compiled_gradients = func.definition.attr[
        "_XlaSeparateCompiledGradients"].b
    xla_scope = func.definition.attr["_XlaScope"].s.decode()
  else:
    try:
      xla_compile = op.get_attr("_XlaCompile")
      xla_separate_compiled_gradients = op.get_attr(
          "_XlaSeparateCompiledGradients")
      xla_scope = op.get_attr("_XlaScope").decode()
    except ValueError:
      return grad_fn()  # Exit early

  if not xla_compile:
    return grad_fn()  # Exit early

  # If the gradients are supposed to be compiled separately, we give them a
  # _XlaScope name that is based on the name_scope of the gradients.  Otherwise
  # they just inherit the existing _XlaScope name, which lets them be merged
  # together with the non-gradient computation.
  if xla_separate_compiled_gradients:
    xla_grad_scope = "%s_grad_%s" % (xla_scope, scope)
  else:
    xla_grad_scope = xla_scope

  attrs = {
      "_XlaCompile": attr_value_pb2.AttrValue(b=xla_compile),
      "_XlaScope": attr_value_pb2.AttrValue(s=xla_grad_scope.encode())
  }
  with ops.get_default_graph()._attr_scope(attrs):  # pylint: disable=protected-access
    return grad_fn()


def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs):
  """Raises an error if we backprop through a loop var."""
  # Find the nearest op in 'from_ops' reachable from 'op' to provide a more
  # helpful error message.
  target_op = None
  queue = collections.deque([op])
  visited = set()
  while queue:
    curr_op = queue.popleft()
    if curr_op in visited: continue
    visited.add(curr_op)
    if curr_op in from_ops:
      target_op = curr_op
      break
    queue.extend(t.op for t in _NonEagerInputs(curr_op, xs))
  assert target_op
  raise ValueError(
      "Cannot compute gradient inside while loop with respect to op '%s'. "
      "We do not support taking the gradient wrt or through the initial value "
      "of a loop variable. Gradients can be computed through loop invariants "
      "or wrt the input parameters to the loop body."
      % target_op.name)


def _IsFunction(graph):
  return (isinstance(graph, FuncGraph) or
          isinstance(graph, framework_function._FuncGraph))  # pylint: disable=protected-access


def _Captures(func_graph):
  if isinstance(func_graph, FuncGraph):
    return func_graph.captures
  else:
    assert isinstance(func_graph, framework_function._FuncGraph)  # pylint: disable=protected-access
    return func_graph._captured  # pylint: disable=protected-access


def _MaybeCaptured(t):
  """If t is a captured value placeholder, returns the original captured value.

  Args:
    t: Tensor

  Returns:
    A tensor, potentially from a different Graph/FuncGraph.
  """
  # pylint: disable=protected-access
  if (not isinstance(t, ops.EagerTensor) and
      _IsFunction(t.op.graph) and t.op.type == "Placeholder"):
    for input_t, placeholder_t in _Captures(t.op.graph).items():
      if t == placeholder_t:
        return _MaybeCaptured(input_t)
  # pylint: enable=protected-access
  return t


def _NonEagerInputs(op, xs):
  """Returns the inputs of op, crossing closure boundaries where necessary.

  Does not return any captured EagerTensors, i.e., the number of tensors
  returned may be less than the actual number of inputs.

  Args:
    op: Operation
    xs: list of Tensors we are differentiating w.r.t.

  Returns:
    A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
    is in a FuncGraph and has captured inputs.
  """
  return [t for t in _Inputs(op, xs) if not isinstance(t, ops.EagerTensor)]


# TODO(skyewm): plumbing xs through everywhere is ugly, consider making
# _GradientsHelper a class with xs as a member variable.
def _Inputs(op, xs):
  """Returns the inputs of op, crossing closure boundaries where necessary.

  Args:
    op: Operation
    xs: list of Tensors we are differentiating w.r.t.

  Returns:
    A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
    is in a FuncGraph and has captured inputs.
  """
  if _IsFunction(op.graph):  # pylint: disable=protected-access
    inputs = []
    for t in op.inputs:
      # If we're differentiating w.r.t. `t`, do not attempt to traverse through
      # it to a captured value. The algorithm needs to "see" `t` in this case,
      # even if it's a function input for a captured value, whereas usually we'd
      # like to traverse through these closures as if the captured value was the
      # direct input to op.
      if t not in xs:
        t = _MaybeCaptured(t)
      inputs.append(t)
    return inputs
  else:
    return op.inputs


def _Consumers(t, func_graphs):
  """Returns the consumers of t, crossing closure boundaries where necessary.

  Args:
    t: Tensor
    func_graphs: a list of FuncGraphs that may have captured t.

  Returns:
    A list of tensors. The tensors will be from the current graph and/or
    func_graphs.
  """
  consumers = t.consumers()
  for func in func_graphs:
    for input_t, placeholder in _Captures(func).items():
      if input_t == t:
        consumers.extend(_Consumers(placeholder, func_graphs))
  return consumers


def _GradientsHelper(ys,
                     xs,
                     grad_ys=None,
                     name="gradients",
                     colocate_gradients_with_ops=False,
                     gate_gradients=False,
                     aggregation_method=None,
                     stop_gradients=None,
                     unconnected_gradients=UnconnectedGradients.NONE,
                     src_graph=None):
  """Implementation of gradients()."""
  if context.executing_eagerly():
    raise RuntimeError("tf.gradients is not supported when eager execution "
                       "is enabled. Use tf.GradientTape instead.")
  if src_graph is None:
    src_graph = ops.get_default_graph()
  try:
    unconnected_gradients = UnconnectedGradients(unconnected_gradients)
  except ValueError:
    raise ValueError(
        "Unknown value for unconnected_gradients: %r" % unconnected_gradients)

  # If src_graph is a _FuncGraph (i.e. a function body), gather it and all
  # ancestor graphs. This is necessary for correctly handling captured values.
  func_graphs = []
  curr_graph = src_graph
  while _IsFunction(curr_graph):
    func_graphs.append(curr_graph)
    if isinstance(curr_graph, FuncGraph):
      curr_graph = curr_graph.outer_graph
    else:
      assert isinstance(curr_graph, framework_function._FuncGraph)  # pylint: disable=protected-access
      curr_graph = curr_graph._outer_graph  # pylint: disable=protected-access

  ys = _AsList(ys)
  xs = _AsList(xs)
  stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)
  if grad_ys is None:
    grad_ys = [None] * len(ys)
  else:
    grad_ys = _AsList(grad_ys)

  with ops.name_scope(
      name, "gradients",
      list(ys) + list(xs) + list(stop_gradients) + list(grad_ys)) as grad_scope:
    # Get a uid for this call to gradients that can be used to help
    # cluster ops for compilation.
    gradient_uid = ops.get_default_graph().unique_name("uid")
    ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
    xs = [
        x.handle if resource_variable_ops.is_resource_variable(x) else x
        for x in xs
    ]
    xs = ops.internal_convert_n_to_tensor_or_indexed_slices(
        xs, name="x", as_ref=True)
    grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops,
                             gradient_uid)

    # The approach we take here is as follows: Create a list of all ops in the
    # subgraph between the ys and xs.  Visit these ops in reverse order of ids
    # to ensure that when we visit an op the gradients w.r.t its outputs have
    # been collected.  Then aggregate these gradients if needed, call the op's
    # gradient function, and add the generated gradients to the gradients for
    # its input.

    # Initialize the pending count for ops in the connected subgraph from ys
    # to the xs.
    to_ops = [t.op for t in ys]
    from_ops = [t.op for t in xs]
    stop_gradient_ops = [t.op for t in stop_gradients]
    reachable_to_ops, pending_count, loop_state = _PendingCount(
        to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs)

    # Iterate over the collected ops.
    #
    # grads: op => list of gradients received on each output endpoint of the
    # op.  The gradients for each endpoint are initially collected as a list.
    # When it is time to call the op's gradient function, for each endpoint we
    # aggregate the list of received gradients with an Add() Operation if there
    # is more than one.
    grads = {}

    # Add the initial gradients for the ys.
    for y, grad_y in zip(ys, grad_ys):
      _SetGrad(grads, y, grad_y)

    # Initialize queue with to_ops.
    queue = collections.deque()
    # Add the ops in 'to_ops' into the queue.
    to_ops_set = set()
    for op in to_ops:
      # 'ready' handles the case where one output gradient relies on
      # another output's gradient.
      ready = (pending_count[op] == 0)
      if ready and op not in to_ops_set and op in reachable_to_ops:
        to_ops_set.add(op)
        queue.append(op)

    if loop_state:
      loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set)
      for y in loop_exits:
        if IsTrainable(y):
          _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
          queue.append(y.op)

    stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs)
    while queue:
      # generate gradient subgraph for op.
      op = queue.popleft()
      with _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops):
        if loop_state:
          loop_state.EnterGradWhileContext(op, before=True)
        out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state,
                                     aggregation_method)
        if loop_state:
          loop_state.ExitGradWhileContext(op, before=True)

        grad_fn = None
        func_call = None
        is_partitioned_call = _IsPartitionedCall(op)
        # pylint: disable=protected-access
        is_func_call = (
            src_graph._is_function(op.type) or is_partitioned_call)
        # pylint: enable=protected-access
        has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
        if has_out_grads and (op not in stop_ops):
          try:
            grad_fn = ops.get_gradient_function(op)
          except LookupError:
            if is_func_call:
              if is_partitioned_call:
                func_call = src_graph._get_function(  # pylint: disable=protected-access
                    compat.as_bytes(op.get_attr("f").name))
              else:
                func_call = src_graph._get_function(op.type)  # pylint: disable=protected-access
              # Note that __defun is not set if the graph is
              # imported. If it's set, we prefer to access the original
              # defun.
              func_call = getattr(op, "__defun", func_call)
              grad_fn = func_call.python_grad_func
            else:
              raise LookupError(
                  "No gradient defined for operation '%s' (op type: %s)" %
                  (op.name, op.type))
        if loop_state:
          loop_state.EnterGradWhileContext(op, before=False)

        # NOTE(skyewm): We don't support computing gradients wrt a loop variable
        # unless it's within the context of a single iteration (i.e. the
        # gradient is wrt the loop parameter in the body function, not wrt or
        # through the initial value). This means if we're in a while loop
        # context, we should never see a switch node from this context.
        # pylint: disable=protected-access
        if (control_flow_util.IsSwitch(op) and
            op._control_flow_context is not None and
            op._control_flow_context.IsWhileContext() and
            op._control_flow_context ==
            ops.get_default_graph()._get_control_flow_context()):
          _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs)
        # pylint: enable=protected-access

        if (grad_fn or is_func_call) and has_out_grads:
          # NOTE: If _AggregatedGrads didn't compute a value for the i'th
          # output, it means that the cost does not depend on output[i],
          # therefore dC/doutput[i] is 0.
          for i, out_grad in enumerate(out_grads):
            if (not isinstance(out_grad, ops.Tensor) and not out_grad) and (
                (not grad_fn and is_func_call) or IsTrainable(op.outputs[i])):
              # Only trainable outputs or outputs for a function call that
              # will use SymbolicGradient get a zero gradient. Gradient
              # functions should ignore the gradient for other outputs.
              # TODO(apassos) gradients of resource handles might be an
              # issue here because of zeros.
              if loop_state:
                out_grads[i] = loop_state.ZerosLike(op, i)
              else:
                out_grads[i] = control_flow_ops.ZerosLikeOutsideLoop(op, i)
          with ops.name_scope(op.name + "_grad"):
            # pylint: disable=protected-access
            with src_graph._original_op(op):
              # pylint: enable=protected-access
              if grad_fn:
                # If grad_fn was found, do not use SymbolicGradient even for
                # functions.
                in_grads = _MaybeCompile(grad_scope, op, func_call,
                                         lambda: grad_fn(op, *out_grads))
              else:
                # For function call ops, we add a 'SymbolicGradient'
                # node to the graph to compute gradients.
                in_grads = _MaybeCompile(grad_scope, op, func_call,
                                         lambda: _SymGrad(op, out_grads))
              in_grads = _AsList(in_grads)
              _VerifyGeneratedGradients(in_grads, op)
              if gate_gradients and len([x for x in in_grads
                                         if x is not None]) > 1:
                with ops.device(None):
                  with ops._colocate_with_for_gradient(  # pylint: disable=protected-access
                      None,
                      gradient_uid,
                      ignore_existing=True):
                    in_grads = control_flow_ops.tuple(in_grads)
          _LogOpGradients(op, out_grads, in_grads)
        else:
          # If no grad_fn is defined or none of out_grads is available,
          # just propagate a list of None backwards.
          in_grads = [None] * len(_Inputs(op, xs))
        # Note: we don't filter out eager inputs here because the inputs need to
        # line up with in_grads.
        for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs), in_grads)):
          if in_grad is not None:
            if (isinstance(in_grad, ops.Tensor) and
                t_in.dtype != dtypes.resource):
              try:
                in_grad.set_shape(t_in.get_shape())
              except ValueError:
                raise ValueError(
                    "Incompatible shapes between op input and calculated "
                    "input gradient.  Forward operation: %s.  Input index: %d. "
                    "Original input shape: %s.  "
                    "Calculated input gradient shape: %s" %
                    (op.name, i, t_in.shape, in_grad.shape))
            if not isinstance(t_in, ops.EagerTensor):
              _SetGrad(grads, t_in, in_grad)
        if loop_state:
          loop_state.ExitGradWhileContext(op, before=False)

      # Update pending count for the inputs of op and enqueue ready ops.
      _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
                                    xs)

  if loop_state:
    loop_state.PostProcessing()
  return [_GetGrad(grads, x, unconnected_gradients) for x in xs]

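# Illustrative sketch (hypothetical user code): _GradientsHelper is normally
# reached through the public tf.gradients wrapper. Assuming graph mode,
#
#   x = tf.placeholder(tf.float32, [None, 3])
#   w = tf.Variable(tf.ones([3, 1]))
#   unused = tf.constant(5.0)
#   y = tf.matmul(x, w)
#   dw, du = tf.gradients(
#       y, [w, unused],
#       unconnected_gradients=tf.UnconnectedGradients.ZERO)
#
# dw is the usual backprop result, while du is a zeros_like(unused) tensor
# because `unused` is not connected to y; with the default NONE setting it
# would be returned as None instead.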

def _HasAnyNotNoneGrads(grads, op):
  """Return true iff op has real gradient."""
  out_grads = _GetGrads(grads, op)
  for out_grad in out_grads:
    if isinstance(out_grad, (ops.Tensor, ops.IndexedSlices)):
      return True
    if out_grad and isinstance(out_grad, collections.Sequence):
      if any(g is not None for g in out_grad):
        return True
  return False


def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
                                  xs):
  """Update pending count for the inputs of op and enqueue ready ops."""
  for x in _NonEagerInputs(op, xs):
    pending_count[x.op] -= 1
    ready = (pending_count[x.op] == 0)
    if loop_state and not ready:
      ready = pending_count[x.op] > 0 and control_flow_util.IsLoopSwitch(x.op)
    if ready:
      if control_flow_util.IsLoopExit(x.op):
        # if x is an exit without a real gradient, defer processing it.
        grad_state = loop_state.GetGradState(x.op, before=False)
        grad_state.deferred_exits.append(x)
        grad_state.pending_exits_count -= 1
        if grad_state.pending_exits_count == 0:
          # We now have all the exits so process them.
          has_not_none_grad = False
          for y in grad_state.deferred_exits:
            if _HasAnyNotNoneGrads(grads, y.op):
              has_not_none_grad = True
              queue.append(y.op)
            else:
              grad_state.unused_exits.append(y)
          if has_not_none_grad:
            # For an unused exit, if it has trainable outputs, backprop
            # a zero gradient. Otherwise, just ignore it.
            for y in grad_state.unused_exits:
              if IsTrainable(y):
                _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
              queue.append(y.op)
          else:
            # All exits are "unused" so use None as gradient.
            for y in grad_state.unused_exits:
              queue.append(y.op)
      else:
        queue.append(x.op)


def _SetGrad(grads, t, grad):
  """Sets gradient "grad" in "grads" for tensor "t"."""
  op = t.op
  op_grads = grads.get(op)
  if not op_grads:
    op_grads = [[] for _ in xrange(len(op.outputs))]
    grads[op] = op_grads
  t_grads = op_grads[t.value_index]
  if isinstance(t_grads, list):
    t_grads.append(grad)
  else:
    assert control_flow_util.IsLoopSwitch(op)
    op_grads[t.value_index] = grad


def _GetGrad(grads, t, unconnected_gradients):
  """Gets gradient for tensor "t"."""
  op = t.op
  op_grads = grads.get(op)
  if not op_grads:
    if unconnected_gradients == UnconnectedGradients.ZERO:
      t_dtype = t.dtype if t.dtype != dtypes.resource else dtypes.float32
      return array_ops.zeros_like(t, dtype=t_dtype)
    elif unconnected_gradients == UnconnectedGradients.NONE:
      return None
    else:
      raise ValueError(
          "Unknown value for unconnected_gradients: %r" % unconnected_gradients)

  t_grad = op_grads[t.value_index]
  assert not isinstance(
      t_grad, list), ("gradients list should have been aggregated by now.")
  return t_grad


def _GetGrads(grads, op):
  """Gets all gradients for op."""
  if op in grads:
    return grads[op]
  else:
    return [[] for _ in xrange(len(op.outputs))]


def _HandleNestedIndexedSlices(grad):
  assert isinstance(grad, ops.IndexedSlices)
  if isinstance(grad.values, ops.Tensor):
    return grad
  else:
    assert isinstance(grad.values, ops.IndexedSlices)
    g = _HandleNestedIndexedSlices(grad.values)
    return ops.IndexedSlices(g.values, array_ops.gather(
        grad.indices, g.indices), g.dense_shape)


def _AccumulatorShape(inputs):
  shape = tensor_shape.unknown_shape()
  for i in inputs:
    if isinstance(i, ops.Tensor):
      shape = shape.merge_with(i.get_shape())
  return shape


def _LogOpGradients(op, out_grads, in_grads):
  """Log the in and out grads of an op."""
  logging.vlog(1, "Gradient for '" + op.name + "'")

  def _FilterGrad(x):
    if x is None:
      return False
    if isinstance(x, (list, tuple)):
      return bool(x)
    else:
      return True

  logging.vlog(1, "  in  --> %s",
               ", ".join([x.name for x in out_grads if _FilterGrad(x)]))
  logging.vlog(1, "  out --> %s",
               ", ".join([x.name for x in in_grads if _FilterGrad(x)]))


def _MultiDeviceAddN(tensor_list, gradient_uid):
  """Adds tensors from potentially multiple devices."""
  # Basic function structure comes from control_flow_ops.group().
  # Sort tensors according to their devices.
  tensors_on_device = collections.defaultdict(lambda: [])
  for tensor in tensor_list:
    tensors_on_device[tensor.device].append(tensor)

  # For each device, add the tensors on that device first.
  # Then gather the partial sums from multiple devices.
  # TODO(sjhwang): Create hierarchical aggregation tree as pbar's suggestion.
  # E.g., aggregate per GPU, then per task, and so on.
  summands = []

  def DeviceKey(dev):
    return "" if dev is None else dev

  for dev in sorted(tensors_on_device, key=DeviceKey):
    tensors = tensors_on_device[dev]
    with ops._colocate_with_for_gradient(  # pylint: disable=protected-access
        tensors[0].op,
        gradient_uid,
        ignore_existing=True):
      summands.append(math_ops.add_n(tensors))

  return math_ops.add_n(summands)


@tf_export("AggregationMethod")
class AggregationMethod(object):
  """A class listing aggregation methods used to combine gradients.

  Computing partial derivatives can require aggregating gradient
  contributions. This class lists the various methods that can
  be used to combine gradients in the graph:

  *  `ADD_N`: All of the gradient terms are summed as part of one
     operation using the "AddN" op. It has the property that all
     gradients must be ready before any aggregation is performed.
  *  `DEFAULT`: The system-chosen default aggregation method.
  """
  ADD_N = 0
  DEFAULT = ADD_N
  # The following are experimental and may not be supported in future releases.
  EXPERIMENTAL_TREE = 1
  EXPERIMENTAL_ACCUMULATE_N = 2

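# Illustrative sketch (hypothetical user code): the aggregation method is
# threaded through from tf.gradients, e.g.
#
#   grads = tf.gradients(
#       ys, xs,
#       aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
#
# With the default ADD_N, every contribution to a gradient must be live before
# the single AddN op runs; the experimental tree/accumulate_n variants trade
# some speed for the chance to release partial sums earlier (see
# _AggregatedGrads below).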

def _AggregatedGrads(grads,
                     op,
                     gradient_uid,
                     loop_state,
                     aggregation_method=None):
  """Get the aggregated gradients for op.

  Args:
    grads: The map of memoized gradients.
    op: The op to get gradients for.
    gradient_uid: A unique identifier within the graph indicating
      which invocation of gradients is being executed. Used to cluster
      ops for compilation.
    loop_state: An object for maintaining the state of the while loops in the
                graph. It is of type ControlFlowState. None if the graph
                contains no while loops.
    aggregation_method: Specifies the method used to combine gradient terms.
      Accepted values are constants defined in the class `AggregationMethod`.

  Returns:
    A list of gradients, one for each output of `op`. If the gradient
      for a particular output is a list, this function aggregates it
      before returning.

  Raises:
    TypeError: if the incoming grads are not Tensors or IndexedSlices.
    ValueError: if the arguments are invalid.

  """
  if aggregation_method is None:
    aggregation_method = AggregationMethod.DEFAULT
  if aggregation_method not in [
      AggregationMethod.ADD_N, AggregationMethod.EXPERIMENTAL_TREE,
      AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
  ]:
    raise ValueError(
        "Invalid aggregation_method specified %s." % aggregation_method)
  out_grads = _GetGrads(grads, op)
  for i, out_grad in enumerate(out_grads):
    if loop_state:
      if isinstance(out_grad, (ops.Tensor, ops.IndexedSlices)):
        assert control_flow_util.IsLoopSwitch(op)
        continue
    # Grads have to be Tensors or IndexedSlices
    if (isinstance(out_grad, collections.Sequence) and not all(
        isinstance(g, (ops.Tensor, ops.IndexedSlices))
        for g in out_grad
        if g is not None
    )):
      raise TypeError("gradients have to be either all Tensors "
                      "or all IndexedSlices")
    # Aggregate multiple gradients, and convert [] to None.
    if out_grad:
      if len(out_grad) < 2:
        used = "nop"
        out_grads[i] = out_grad[0]
      elif all(isinstance(g, ops.Tensor) for g in out_grad if g is not None):
        tensor_shape = _AccumulatorShape(out_grad)
        if (aggregation_method == AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
            and len(out_grad) > 2 and tensor_shape.is_fully_defined()):
          # The benefit of using AccumulateN is that its inputs can be combined
          # in any order and this can allow the expression to be evaluated with
          # a smaller memory footprint.  When used with gpu_allocator_retry,
          # it is possible to compute a sum of terms which are much larger than
          # total GPU memory.
          # AccumulateN can currently only be used if we know the shape for
          # an accumulator variable.  If this is not known, or if we only have
          # 2 grads then we fall through to the "tree" case below.
          used = "accumulate_n"
          out_grads[i] = math_ops.accumulate_n(out_grad)
        elif aggregation_method in [
            AggregationMethod.EXPERIMENTAL_TREE,
            AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
        ]:
          # Aggregate all gradients by doing pairwise sums: this may
          # reduce performance, but it can improve memory because the
          # gradients can be released earlier.
          #
          # TODO(vrv): Consider replacing this with a version of
          # tf.AddN() that eagerly frees its inputs as soon as they are
          # ready, so the order of this tree does not become a problem.
          used = "tree"
          with ops.name_scope(op.name + "_gradient_sum"):
            running_sum = out_grad[0]
            for grad in out_grad[1:]:
              running_sum = math_ops.add_n([running_sum, grad])
            out_grads[i] = running_sum
        else:
          used = "add_n"
          out_grads[i] = _MultiDeviceAddN(out_grad, gradient_uid)
        logging.vlog(2, "  _AggregatedGrads %d x %s using %s", len(out_grad),
                     tensor_shape, used)
      else:
        out_grads[i] = _AggregateIndexedSlicesGradients(out_grad)
    else:  # not out_grad
      # out_grads[i] is [], thus its aggregation is simply None.
      out_grads[i] = None
  return out_grads


def _AggregateIndexedSlicesGradients(grads):
  """Aggregates gradients of type `IndexedSlices` by concatenation."""
  if len(grads) < 1:
    return None
  elif len(grads) == 1:
    return grads[0]
  else:
    grads = math_ops._as_indexed_slices_list(  # pylint: disable=protected-access
        [g for g in grads if g is not None])
    grads = [_HandleNestedIndexedSlices(x) for x in grads]  # pylint: disable=protected-access
    # Form IndexedSlices out of the concatenated values and indices.
    concat_grad = ops.IndexedSlices(
        array_ops.concat([x.values for x in grads], axis=0),
        array_ops.concat([x.indices for x in grads], axis=0),
        grads[0].dense_shape)

    return concat_grad

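# Illustrative sketch: two sparse contributions to the same tensor, say
# IndexedSlices(values=[[1.]], indices=[3]) and
# IndexedSlices(values=[[2.]], indices=[3]), are aggregated above by
# concatenating values and indices into indices=[3, 3] rather than summing the
# duplicate rows; deduplication is left to downstream consumers, e.g. the
# unsorted_segment_sum used by _IndexedSlicesToTensor when densifying.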