# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements the graph generation for computation of gradients."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import contextlib

from six.moves import xrange, zip  # pylint: disable=redefined-builtin

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import backprop
from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function as framework_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework.func_graph import FuncGraph
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_state
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import default_gradient
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import object_identity
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export


def _MarkReachedOps(from_ops, reached_ops, func_graphs):
  """Mark all ops reached from "from_ops".

  Args:
    from_ops: list of Operations.
    reached_ops: set of Operations.
    func_graphs: list of FuncGraphs. This method will traverse through
      these functions if they capture from_ops or any reachable ops.
  """
  queue = collections.deque()
  queue.extend(from_ops)
  while queue:
    op = queue.popleft()
    if op not in reached_ops:
      reached_ops.add(op)
      for output in op.outputs:
        if _IsBackpropagatable(output):
          queue.extend(_Consumers(output, func_graphs))


def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
                  xs_set):
  """Initialize the pending count for ops between two lists of Operations.

  'pending_count[op]' indicates the number of backprop inputs
  to this operation.

  Args:
    to_ops: list of Operations.
    from_ops: list of Operations.
    colocate_gradients_with_ops: Python bool.  See docstring of gradients().
    func_graphs: list of FuncGraphs. This method will traverse through
      these functions if they capture from_ops or any reachable ops. This is
      useful if to_ops occur in a function and from_ops are in an outer function
      or graph.
    xs_set: ObjectIdentitySet of Tensors.

  Returns:
    A tuple containing: (1) the subset of to_ops reachable from from_ops by a
    path of zero or more backpropagatable tensors, (2) a mapping from operation
    to the number of backprop inputs to that op, and (3) a ControlFlowState
    object which is not None if the ops between from_ops and to_ops contain
    control flow loops.
  """
  # Mark reachable ops from from_ops.
  reached_ops = set()
  _MarkReachedOps(from_ops, reached_ops, func_graphs)
  # X in reached_ops iff X is reachable from from_ops by a path of zero or more
  # backpropagatable tensors.

  reachable_to_ops = set(op for op in to_ops if op in reached_ops)

  # Mark between ops.
  between_ops = set()
  between_op_list = []
  queue = collections.deque()
  queue.extend(to_ops)
  while queue:
    op = queue.popleft()
    # We are interested in this op.
    if op in reached_ops:
      between_ops.add(op)
      between_op_list.append(op)
      # Remove the op from reached_ops so we won't add its inputs again.
      reached_ops.remove(op)
      for inp in _NonEagerInputs(op, xs_set):
        queue.append(inp.op)
  # X in between_ops iff X is on a path of zero or more backpropagatable tensors
  # between from_ops and to_ops.

  # 'loop_state' is None if there are no while loops.
  loop_state = control_flow_state.MaybeCreateControlFlowState(
      between_op_list, between_ops, colocate_gradients_with_ops)

  # Initialize pending count for between ops.
  pending_count = collections.defaultdict(int)
  for op in between_op_list:
    for x in _NonEagerInputs(op, xs_set):
      if x.op in between_ops:
        pending_count[x.op] += 1

  return reachable_to_ops, pending_count, loop_state


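# Illustrative sketch (added for exposition; not part of TensorFlow): a rough,
# hedged example of how `_MarkReachedOps` and `_PendingCount` see a tiny
# x -> y chain. The graph and names below are hypothetical, and this helper is
# never called by the library.
def _example_pending_count():
  with ops.Graph().as_default():
    x = array_ops.placeholder(dtypes.float32, name="x")
    y = math_ops.square(x, name="y")
    xs_set = object_identity.ObjectIdentitySet([x])
    reachable_to_ops, pending_count, loop_state = _PendingCount(
        to_ops=[y.op], from_ops=[x.op], colocate_gradients_with_ops=False,
        func_graphs=[], xs_set=xs_set)
    # y.op is reachable from x.op, loop_state is None because there are no
    # while loops, and pending_count[x.op] == 1 because x receives backprop
    # from a single consumer (y).
    return reachable_to_ops, pending_count, loop_state

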
def _AsList(x):
  return x if isinstance(x, (list, tuple)) else [x]


def _DefaultGradYs(grad_ys,
                   ys,
                   colocate_gradients_with_ops,
                   gradient_uid="__unsupported__"):
  """Fill in default values for grad_ys.

  Args:
    grad_ys: List of gradients, can contain None.
    ys: List of tensors.
    colocate_gradients_with_ops: If True, try colocating gradients with
      the corresponding op.
    gradient_uid: A unique identifier within the graph indicating
      which invocation of gradients is being executed. Used to cluster
      ops for compilation.

  Returns:
    A list of gradients to use, without None.

  Raises:
    ValueError: If sizes of gradients and inputs don't match.
    TypeError: If type of any gradient is not valid for its input.
  """
  if len(grad_ys) != len(ys):
    raise ValueError("Passed %d grad_ys for %d ys" % (len(grad_ys), len(ys)))
  grad_ys = ops.convert_n_to_tensor_or_indexed_slices(grad_ys, name="grad_y")
  new_grad_ys = []
  for i, (y, grad_y) in enumerate(zip(ys, grad_ys)):
    with _maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops):
      if grad_y is None:
        if y.dtype.is_complex:
          raise TypeError(
              "Gradients of complex tensors must set grad_ys (y.dtype = %r)" %
              y.dtype)
        new_grad_ys.append(
            array_ops.ones(
                array_ops.shape(y), dtype=y.dtype, name="grad_ys_%d" % i))
        continue
      if y.dtype.is_floating or y.dtype.is_integer:
        if not grad_y.dtype.is_floating and not grad_y.dtype.is_integer:
          raise TypeError(
              "Gradient type %s generated for real or "
              "integer-valued tensor %s with type %s must be "
              "real or integer" % (dtypes.as_dtype(grad_y.dtype).name, y,
                                   dtypes.as_dtype(y.dtype).name))
      elif y.dtype.is_complex:
        if not grad_y.dtype.is_complex:
          raise TypeError(
              "Gradient type %s generated for complex-valued "
              "tensor %s with type %s must be complex" % (dtypes.as_dtype(
                  grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name))
      elif y.dtype == dtypes.variant:
        if grad_y.dtype != dtypes.variant:
          raise TypeError(
              "Gradient type %s generated for variant "
              "tensor %s with type %s must be variant" % (dtypes.as_dtype(
                  grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name))
      elif y.dtype == dtypes.resource:
        # We assume y is the handle of a ResourceVariable. The gradient of a
        # ResourceVariable should be a numeric value, not another resource.
        if grad_y.dtype == dtypes.resource:
          raise TypeError("Input gradient %s for resource tensor %s should not "
                          "be a resource" % (grad_y, y))
      else:
        raise TypeError(
            "Tensor %s with type %s must be numeric "
            "to obtain a default gradient" % (y, dtypes.as_dtype(y.dtype).name))
      # Create a grad_y tensor in the name scope of the gradient.
      # Required for TensorArrays to identify which gradient call a
      # grad_y value is coming from.
      if isinstance(grad_y, ops.IndexedSlices):
        new_grad_ys.append(
            ops.IndexedSlices(
                indices=(array_ops.identity(
                    grad_y.indices, name="grad_ys_%d_indices" % i)
                         if isinstance(grad_y.indices, ops.Tensor) else
                         grad_y.indices),
                values=(array_ops.identity(
                    grad_y.values, name="grad_ys_%d_values" % i) if isinstance(
                        grad_y.values, ops.Tensor) else grad_y.values),
                dense_shape=(array_ops.identity(
                    grad_y.dense_shape, name="grad_ys_%d_shape" % i)
                             if isinstance(grad_y.dense_shape, ops.Tensor) else
                             grad_y.dense_shape)))
      else:
        new_grad_ys.append(array_ops.identity(grad_y, name="grad_ys_%d" % i))

  return new_grad_ys


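# Illustrative sketch (added for exposition; not part of TensorFlow): when
# `grad_ys` contains None, `_DefaultGradYs` fills in a tensor of ones shaped
# like the corresponding `y`, as described in the docstring above. This is a
# hypothetical helper and is never called by the library.
def _example_default_grad_ys():
  with ops.Graph().as_default():
    y = array_ops.ones([2, 3], dtype=dtypes.float32, name="y")
    (grad_y,) = _DefaultGradYs([None], [y], colocate_gradients_with_ops=False)
    # grad_y is a float32 tensor of ones with the same shape as y.
    return grad_y

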
def _IsBackpropagatable(tensor):
  if backprop_util.IsTrainable(tensor):
    return True
  dtype = dtypes.as_dtype(tensor.dtype)
  return dtype.base_dtype == dtypes.bfloat16


def _VerifyGeneratedGradients(grads, op):
  """Verify that gradients are valid in number and type.

  Args:
    grads: List of generated gradients.
    op: Operation for which the gradients were generated.

  Raises:
    ValueError: if sizes of gradients and inputs don't match.
    TypeError: if type of any gradient is not valid for its input.
  """
  # `While` ops have inputs added to them during the gradient computation, so
  # we skip the check below. See while_v2 for details.
  if op.type == "While" or op.type == "StatelessWhile":
    return

  if len(grads) != len(op.inputs):
    raise ValueError("Num gradients %d generated for op %s do not match num "
                     "inputs %d" % (len(grads), op.node_def, len(op.inputs)))


def _StopOps(from_ops, stop_gradient_ops, pending_count, xs_set):
  """The set of ops that terminate the gradient computation.

  This computes the frontier of the forward graph *before* which backprop
  should stop. Operations in the returned set will not be differentiated.
  This set is defined as the subset of `from_ops` containing ops that have
  no predecessor in `from_ops`. `pending_count` is the result of
  `_PendingCount(to_ops, from_ops, ...)`. An 'op' has predecessors in
  `from_ops` iff pending_count[op] > 0.

  In addition, none of `stop_gradient_ops` will be differentiated.

  Args:
    from_ops: list of Operations.
    stop_gradient_ops: list of Operations never to backprop through.
    pending_count: mapping from operation to number of backprop inputs.
    xs_set: ObjectIdentitySet of Tensors.

  Returns:
    The set of operations.
  """
  stop_ops = set()
  for op in from_ops:
    is_stop_op = True
    for inp in _NonEagerInputs(op, xs_set):
      if pending_count[inp.op] > 0:
        is_stop_op = False
        break
    if is_stop_op:
      stop_ops.add(op)
  stop_ops.update(op for op in stop_gradient_ops)
  return stop_ops


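# Illustrative sketch (added for exposition; not part of TensorFlow): for a
# simple chain x -> y -> z, `_StopOps` keeps the xs frontier (x.op, which has
# no predecessor in from_ops) and always includes the explicitly requested
# stop_gradient_ops. Hypothetical helper; never called by the library.
def _example_stop_ops():
  with ops.Graph().as_default():
    x = array_ops.placeholder(dtypes.float32, shape=[], name="x")
    y = math_ops.square(x, name="y")
    z = math_ops.square(y, name="z")
    xs_set = object_identity.ObjectIdentitySet([x])
    _, pending_count, _ = _PendingCount([z.op], [x.op], False, [], xs_set)
    # Expected result: {x.op, y.op} -- x.op because nothing upstream of it is
    # being differentiated, and y.op because it was explicitly listed.
    return _StopOps([x.op], [y.op], pending_count, xs_set)

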
@contextlib.contextmanager
def _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops):  # pylint: disable=invalid-name
  """Context to colocate with `op` if `colocate_gradients_with_ops`."""
  if colocate_gradients_with_ops:
    with ops._colocate_with_for_gradient(op, gradient_uid):  # pylint: disable=protected-access
      yield
  else:
    yield


def _IsPartitionedCall(op):
  return op.type == "PartitionedCall" or op.type == "StatefulPartitionedCall"


def _SymGrad(op, out_grads):
  """Backprop through a function call node op given its outputs' gradients."""
  f_in = [x for x in op.inputs] + out_grads
  f_types = [default_gradient.get_zeros_dtype(x) for x in op.inputs]
  f = attr_value_pb2.NameAttrList()
  if _IsPartitionedCall(op):
    f.name = op.get_attr("f").name
  else:
    f.name = op.type
  for k in op.node_def.attr:
    f.attr[k].CopyFrom(op.node_def.attr[k])
  in_grads = functional_ops.symbolic_gradient(input=f_in, Tout=f_types, f=f)
  return in_grads


def _MaybeCompile(scope, op, func, grad_fn):
  """Compile the calculation in grad_fn if op was marked as compiled."""
  scope = scope.rstrip("/").replace("/", "_")
  if func is not None:
    xla_compile = func.definition.attr["_XlaCompile"].b
    xla_separate_compiled_gradients = func.definition.attr[
        "_XlaSeparateCompiledGradients"].b
    xla_scope = func.definition.attr["_XlaScope"].s.decode()
  else:
    try:
      xla_compile = op.get_attr("_XlaCompile")
      xla_separate_compiled_gradients = op.get_attr(
          "_XlaSeparateCompiledGradients")
      xla_scope = op.get_attr("_XlaScope").decode()
    except ValueError:
      xla_compile = False

  if not xla_compile:
    return grad_fn()  # Exit early

  # If the gradients are supposed to be compiled separately, we give them a
  # _XlaScope name that is based on the name_scope of the gradients.  Otherwise
  # they just inherit the existing _XlaScope name, which lets them be merged
  # together with the non-gradient computation.
  if xla_separate_compiled_gradients:
    xla_grad_scope = "%s_grad_%s" % (xla_scope, scope)
  else:
    xla_grad_scope = xla_scope

  attrs = {
      "_XlaCompile": attr_value_pb2.AttrValue(b=xla_compile),
      "_XlaScope": attr_value_pb2.AttrValue(s=xla_grad_scope.encode())
  }
  with ops.get_default_graph()._attr_scope(attrs):  # pylint: disable=protected-access
    return grad_fn()


def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs_set):
  """Raises an error if we backprop through a loop var."""
  # Find the nearest 'to_op' reachable from 'op' to provide a more helpful error
  # message.
  target_op = None
  queue = collections.deque([op])
  visited = set()
  while queue:
    curr_op = queue.popleft()
    if curr_op in visited: continue
    visited.add(curr_op)
    if curr_op in from_ops:
      target_op = curr_op
      break
    queue.extend(t.op for t in _NonEagerInputs(curr_op, xs_set))
  assert target_op
  raise ValueError(
      "Cannot compute gradient inside while loop with respect to op '%s'. "
      "We do not support taking the gradient wrt or through the initial value "
      "of a loop variable. Gradients can be computed through loop invariants "
      "or wrt the input parameters to the loop body."
      % target_op.name)


def _IsFunction(graph):
  return (isinstance(graph, FuncGraph) or
          isinstance(graph, framework_function._FuncGraph))  # pylint: disable=protected-access


def _Captures(func_graph):
  if isinstance(func_graph, FuncGraph):
    return func_graph.captures
  else:
    assert isinstance(func_graph, framework_function._FuncGraph)  # pylint: disable=protected-access
    return func_graph.captures


def _MaybeCaptured(t):
  """If t is a captured value placeholder, returns the original captured value.

  Args:
    t: Tensor

  Returns:
    A tensor, potentially from a different Graph/FuncGraph.
  """
  # pylint: disable=protected-access
  if (not isinstance(t, ops.EagerTensor) and
      _IsFunction(t.op.graph) and t.op.type == "Placeholder"):
    for input_t, placeholder_t in _Captures(t.op.graph):
      if t is placeholder_t:
        return _MaybeCaptured(input_t)
  # pylint: enable=protected-access
  return t


def _NonEagerInputs(op, xs_set):
  """Returns the inputs of op, crossing closure boundaries where necessary.

  Does not return any captured EagerTensors, i.e., the number of tensors
  returned may be less than the actual number of inputs.

  Args:
    op: Operation
    xs_set: ObjectIdentitySet of Tensors we are differentiating w.r.t.

  Returns:
    A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
    is in a FuncGraph and has captured inputs.
  """
  return [t for t in _Inputs(op, xs_set) if not isinstance(t, ops.EagerTensor)]


# TODO(skyewm): plumbing xs through everywhere is ugly, consider making
# _GradientsHelper a class with xs as a member variable.
def _Inputs(op, xs_set):
  """Returns the inputs of op, crossing closure boundaries where necessary.

  Args:
    op: Operation
    xs_set: ObjectIdentitySet of Tensors we are differentiating w.r.t.

  Returns:
    A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
    is in a FuncGraph and has captured inputs.
  """
  if _IsFunction(op.graph):  # pylint: disable=protected-access
    inputs = []
    for t in op.inputs:
      # If we're differentiating w.r.t. `t`, do not attempt to traverse through
      # it to a captured value. The algorithm needs to "see" `t` in this case,
      # even if it's a function input for a captured value, whereas usually we'd
      # like to traverse through these closures as if the captured value was the
      # direct input to op.
      if t not in xs_set:
        t = _MaybeCaptured(t)
      inputs.append(t)
    return inputs
  else:
    return op.inputs


def _Consumers(t, func_graphs):
  """Returns the consumers of t, crossing closure boundaries where necessary.

  Args:
    t: Tensor
    func_graphs: a list of FuncGraphs that may have captured t.

  Returns:
    A list of Operations. The operations may be from the current graph and/or
    func_graphs.
  """
  consumers = t.consumers()
  for func in func_graphs:
    for input_t, placeholder in _Captures(func):
      if input_t is t:
        consumers.extend(_Consumers(placeholder, func_graphs))
  return consumers


def _GradientsHelper(ys,
                     xs,
                     grad_ys=None,
                     name="gradients",
                     colocate_gradients_with_ops=False,
                     gate_gradients=False,
                     aggregation_method=None,
                     stop_gradients=None,
                     unconnected_gradients=UnconnectedGradients.NONE,
                     src_graph=None):
  """Implementation of gradients()."""
  if context.executing_eagerly():
    raise RuntimeError("tf.gradients is not supported when eager execution "
                       "is enabled. Use tf.GradientTape instead.")
  if src_graph is None:
    src_graph = ops.get_default_graph()
  try:
    unconnected_gradients = UnconnectedGradients(unconnected_gradients)
  except ValueError:
    raise ValueError(
        "Unknown value for unconnected_gradients: %r" % unconnected_gradients)

  # If src_graph is a _FuncGraph (i.e. a function body), gather it and all
  # ancestor graphs. This is necessary for correctly handling captured values.
  func_graphs = []
  curr_graph = src_graph
  while _IsFunction(curr_graph):
    func_graphs.append(curr_graph)
    if isinstance(curr_graph, FuncGraph):
      curr_graph = curr_graph.outer_graph
    else:
      assert isinstance(curr_graph, framework_function._FuncGraph)  # pylint: disable=protected-access
      curr_graph = curr_graph._outer_graph  # pylint: disable=protected-access

  ys = _AsList(ys)
  xs = _AsList(xs)
  stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)
  if grad_ys is None:
    grad_ys = [None] * len(ys)
  else:
    grad_ys = _AsList(grad_ys)

  with ops.name_scope(
      name, "gradients",
      list(ys) + list(xs) + list(stop_gradients) + list(grad_ys)) as grad_scope:
    # Get a uid for this call to gradients that can be used to help
    # cluster ops for compilation.
    gradient_uid = ops.get_default_graph().unique_name("uid")
    ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
    xs = [
        x.handle if resource_variable_ops.is_resource_variable(x) else x
        for x in xs
    ]
    xs = ops.internal_convert_n_to_tensor_or_indexed_slices(
        xs, name="x", as_ref=True)
    xs_set = object_identity.ObjectIdentitySet(xs)
    grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops,
                             gradient_uid)

    # The approach we take here is as follows: Create a list of all ops in the
    # subgraph between the ys and xs.  Visit these ops in reverse order of ids
    # to ensure that when we visit an op the gradients w.r.t. its outputs have
    # been collected.  Then aggregate these gradients if needed, call the op's
    # gradient function, and add the generated gradients to the gradients for
    # its input.

    # Initialize the pending count for ops in the connected subgraph from ys
    # to the xs.
    to_ops = [t.op for t in ys]
    from_ops = [t.op for t in xs]
    stop_gradient_ops = [t.op for t in stop_gradients]
    reachable_to_ops, pending_count, loop_state = _PendingCount(
        to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs_set)

    # Iterate over the collected ops.
    #
    # grads: op => list of gradients received on each output endpoint of the
    # op.  The gradients for each endpoint are initially collected as a list.
    # When it is time to call the op's gradient function, for each endpoint we
    # aggregate the list of received gradients into an Add() Operation if there
    # is more than one.
    grads = {}

    # Add the initial gradients for the ys.
    for y, grad_y in zip(ys, grad_ys):
      _SetGrad(grads, y, grad_y)

    # Initialize queue with to_ops.
    queue = collections.deque()
    # Add the ops in 'to_ops' into the queue.
    to_ops_set = set()
    for op in to_ops:
      # 'ready' handles the case where one output gradient relies on
      # another output's gradient.
      ready = (pending_count[op] == 0)
      if ready and op not in to_ops_set and op in reachable_to_ops:
        to_ops_set.add(op)
        queue.append(op)

    if loop_state:
      loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set)
      for y in loop_exits:
        if backprop_util.IsTrainable(y):
          _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
          queue.append(y.op)

    stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs_set)
    while queue:
      # Generate the gradient subgraph for op.
      op = queue.popleft()
      with _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops):
        if loop_state:
          loop_state.EnterGradWhileContext(op, before=True)
        out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state,
                                     aggregation_method)
        if loop_state:
          loop_state.ExitGradWhileContext(op, before=True)

        grad_fn = None
        func_call = None
        is_partitioned_call = _IsPartitionedCall(op)
        # pylint: disable=protected-access
        is_func_call = (
            src_graph._is_function(op.type) or is_partitioned_call)
        # pylint: enable=protected-access
        has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
        if has_out_grads and (op not in stop_ops):
          try:
            grad_fn = ops.get_gradient_function(op)
          except LookupError:
            if is_func_call:
              if is_partitioned_call:
                func_name = compat.as_bytes(op.get_attr("f").name)
                func_call = src_graph._get_function(  # pylint: disable=protected-access
                    func_name)
                # When a graph is imported, the FunctionDefs are not copied over
                # to each sub-graph so we recursively search the outer graphs
                # for the FunctionDef.
                if not func_call and hasattr(src_graph, "outer_graph"):
                  graph = src_graph.outer_graph
                  while graph is not None:
                    func_call = graph._get_function(func_name)  # pylint: disable=protected-access
                    if func_call is not None:
                      break
                    if hasattr(graph, "outer_graph"):
                      graph = graph.outer_graph
                    else:
                      break
              else:
                func_call = src_graph._get_function(op.type)  # pylint: disable=protected-access
              # Note that __defun is not set if the graph is
              # imported. If it's set, we prefer to access the original
              # defun.
              func_call = getattr(op, "__defun", func_call)
              grad_fn = func_call.python_grad_func
            else:
              raise LookupError(
                  "No gradient defined for operation " +
                  "'%s' (op type: %s). " % (op.name, op.type) +
                  "In general every operation must have an associated " +
                  "`@tf.RegisterGradient` for correct autodiff, which this " +
                  "op is lacking. If you want to pretend this " +
                  "operation is a constant in your program, you may insert " +
                  "`tf.stop_gradient`. This can be useful to silence the " +
                  "error in cases where you know gradients are not needed, " +
                  "e.g. the forward pass of tf.custom_gradient. " +
                  "Please see more details in " +
                  "https://www.tensorflow.org/api_docs/python/tf/custom_gradient.")  # pylint: disable=line-too-long
        if loop_state:
          loop_state.EnterGradWhileContext(op, before=False)

        # NOTE(skyewm): We don't support computing gradients wrt a loop variable
        # unless it's within the context of a single iteration (i.e. the
        # gradient is wrt to the loop parameter in the body function, not wrt or
        # through the initial value). This means if we're in a while loop
        # context, we should never see a switch node from this context.
        # pylint: disable=protected-access
        if (control_flow_util.IsSwitch(op) and
            op._control_flow_context is not None and
            op._control_flow_context.IsWhileContext() and
            op._control_flow_context ==
            ops.get_default_graph()._get_control_flow_context()):
          _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs_set)
        # pylint: enable=protected-access

        if (grad_fn or is_func_call) and has_out_grads:
          # NOTE: If _AggregatedGrads didn't compute a value for the i'th
          # output, it means that the cost does not depend on output[i],
          # therefore dC/doutput[i] is 0.
          for i, out_grad in enumerate(out_grads):
            if (not isinstance(out_grad, ops.Tensor) and not out_grad) and (
                (not grad_fn and is_func_call)
                or backprop_util.IsTrainable(op.outputs[i])):
              # Only trainable outputs or outputs for a function call that
              # will use SymbolicGradient get a zero gradient. Gradient
              # functions should ignore the gradient for other outputs.
              # TODO(apassos) gradients of resource handles might be an
              # issue here because of zeros.
              if loop_state:
                out_grads[i] = loop_state.ZerosLikeV1WhileLoop(op, i)
              elif default_gradient.supports_default_grad(op.outputs[i]):
                # TODO(b/143286622): The supports_default_grad check is needed
                # because While op emits non-differentiable resource tensors
                # as outputs. Remove this check when that is not the case.
                out_grads[i] = control_flow_state.ZerosLike(op, i)
          with ops.name_scope(op.name + "_grad"):
            # pylint: disable=protected-access
            with src_graph._original_op(op):
              # pylint: enable=protected-access
              if grad_fn:
                # If grad_fn was found, do not use SymbolicGradient even for
                # functions.
                in_grads = _MaybeCompile(grad_scope, op, func_call,
                                         lambda: grad_fn(op, *out_grads))
              else:
                # For function call ops, we add a 'SymbolicGradient'
                # node to the graph to compute gradients.
                in_grads = _MaybeCompile(grad_scope, op, func_call,
                                         lambda: _SymGrad(op, out_grads))
              in_grads = _AsList(in_grads)
              _VerifyGeneratedGradients(in_grads, op)
              if gate_gradients and len([x for x in in_grads
                                         if x is not None]) > 1:
                with ops.device(None):
                  with ops._colocate_with_for_gradient(  # pylint: disable=protected-access
                      None,
                      gradient_uid,
                      ignore_existing=True):
                    in_grads = control_flow_ops.tuple(in_grads)
          _LogOpGradients(op, out_grads, in_grads)
        else:
          # If no grad_fn is defined or none of out_grads is available,
          # just propagate a list of None backwards.
          in_grads = [None] * len(_Inputs(op, xs_set))
        # Note: we don't filter out eager inputs here because the inputs need to
        # line up with in_grads.
        for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs_set), in_grads)):
          if in_grad is not None:
            if (isinstance(in_grad, ops.Tensor) and
                t_in.dtype != dtypes.resource):
              try:
                in_grad.set_shape(t_in.get_shape())
              except ValueError:
                raise ValueError(
                    "Incompatible shapes between op input and calculated "
                    "input gradient.  Forward operation: %s.  Input index: %d. "
                    "Original input shape: %s.  "
                    "Calculated input gradient shape: %s" %
                    (op.name, i, t_in.shape, in_grad.shape))
            if not isinstance(t_in, ops.EagerTensor):
              _SetGrad(grads, t_in, in_grad)
        if loop_state:
          loop_state.ExitGradWhileContext(op, before=False)

      # Update pending count for the inputs of op and enqueue ready ops.
      _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
                                    xs_set)

  if loop_state:
    loop_state.PostProcessing()
  return [_GetGrad(grads, x, unconnected_gradients) for x in xs]


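# Illustrative sketch (added for exposition; not part of TensorFlow):
# `_GradientsHelper` is the graph-mode machinery behind `tf.gradients`. The
# hypothetical helper below builds the symbolic gradient of x**2 with respect
# to x in a throwaway graph; it is never called by the library.
def _example_gradients_helper():
  with ops.Graph().as_default():
    x = array_ops.placeholder(dtypes.float32, shape=[], name="x")
    y = math_ops.square(x, name="y")
    # One gradient tensor is returned per x; here it is the graph for 2*x.
    (dy_dx,) = _GradientsHelper([y], [x])
    return dy_dx

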
def _HasAnyNotNoneGrads(grads, op):
  """Return true iff op has real gradient."""
  out_grads = _GetGrads(grads, op)
  for out_grad in out_grads:
    if isinstance(out_grad, (ops.Tensor, ops.IndexedSlices)):
      return True
    if out_grad and isinstance(out_grad, collections_abc.Sequence):
      if any(g is not None for g in out_grad):
        return True
  return False


def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
                                  xs_set):
  """Update pending count for the inputs of op and enqueue ready ops."""
  for x in _NonEagerInputs(op, xs_set):
    pending_count[x.op] -= 1
    ready = (pending_count[x.op] == 0)
    if loop_state and not ready:
      ready = pending_count[x.op] > 0 and control_flow_util.IsLoopSwitch(x.op)
    if ready:
      if control_flow_util.IsLoopExit(x.op):
        # If x is an exit without a real gradient, defer processing it.
        grad_state = loop_state.GetGradState(x.op, before=False)
        grad_state.deferred_exits.append(x)
        grad_state.pending_exits_count -= 1
        if grad_state.pending_exits_count == 0:
          # We now have all the exits so process them.
          has_not_none_grad = False
          for y in grad_state.deferred_exits:
            if _HasAnyNotNoneGrads(grads, y.op):
              has_not_none_grad = True
              queue.append(y.op)
            else:
              grad_state.unused_exits.append(y)
          if has_not_none_grad:
            # For an unused exit, if it has trainable outputs, backprop
            # a zero gradient. Otherwise, just ignore it.
            for y in grad_state.unused_exits:
              if backprop_util.IsTrainable(y):
                _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
              queue.append(y.op)
          else:
            # All exits are "unused" so use None as gradient.
            for y in grad_state.unused_exits:
              queue.append(y.op)
      else:
        queue.append(x.op)


def _SetGrad(grads, t, grad):
  """Sets gradient "grad" in "grads" for tensor "t"."""
  op = t.op
  op_grads = grads.get(op)
  if not op_grads:
    op_grads = [[] for _ in xrange(len(op.outputs))]
    grads[op] = op_grads
  t_grads = op_grads[t.value_index]
  if isinstance(t_grads, list):
    t_grads.append(grad)
  else:
    assert control_flow_util.IsLoopSwitch(op)
    op_grads[t.value_index] = grad


def _ZerosLike(t):
  t_dtype = default_gradient.get_zeros_dtype(t)
  if t.dtype == dtypes.resource:
    return array_ops.zeros(
        resource_variable_ops.variable_shape(t), dtype=t_dtype)
  else:
    return array_ops.zeros_like(t, dtype=t_dtype)


def _GetGrad(grads, t, unconnected_gradients):
  """Gets gradient for tensor "t"."""
  op = t.op
  op_grads = grads.get(op)
  if not op_grads:
    if unconnected_gradients == UnconnectedGradients.ZERO:
      return _ZerosLike(t)
    elif unconnected_gradients == UnconnectedGradients.NONE:
      return None
    else:
      raise ValueError(
          "Unknown value for unconnected_gradients: %r" % unconnected_gradients)

  t_grad = op_grads[t.value_index]
  # This can happen if some other output of `t.op` has non-None grad.
  if unconnected_gradients == UnconnectedGradients.ZERO and t_grad is None:
    return _ZerosLike(t)

  assert not isinstance(
      t_grad, list), ("gradients list should have been aggregated by now.")
  return t_grad


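# Illustrative sketch (added for exposition; not part of TensorFlow): how the
# `unconnected_gradients` modes handled above surface through
# `_GradientsHelper` when y does not depend on x. Hypothetical helper; never
# called by the library.
def _example_unconnected_gradients():
  with ops.Graph().as_default():
    x = array_ops.placeholder(dtypes.float32, shape=[2], name="x")
    y = array_ops.ones([2], name="y")  # y does not depend on x.
    (dy_dx,) = _GradientsHelper(
        [y], [x], unconnected_gradients=UnconnectedGradients.ZERO)
    # With ZERO the result is a zeros tensor shaped like x (via _ZerosLike);
    # with the default NONE it would simply be None.
    return dy_dx

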
def _GetGrads(grads, op):
  """Gets all gradients for op."""
  if op in grads:
    return grads[op]
  else:
    return [[] for _ in xrange(len(op.outputs))]


def _AccumulatorShape(inputs):
  shape = tensor_shape.unknown_shape()
  for i in inputs:
    if isinstance(i, ops.Tensor):
      shape = shape.merge_with(i.get_shape())
  return shape


def _LogOpGradients(op, out_grads, in_grads):
  """Log the in and out grads of an op."""
  logging.vlog(1, "Gradient for '" + op.name + "'")

  def _FilterGrad(x):
    if x is None:
      return False
    if isinstance(x, (list, tuple)):
      return bool(x)
    else:
      return True

  logging.vlog(1, "  in  --> %s",
               ", ".join(x.name for x in out_grads if _FilterGrad(x)))
  logging.vlog(1, "  out --> %s",
               ", ".join(x.name for x in in_grads if _FilterGrad(x)))


def _MultiDeviceAddN(tensor_list, gradient_uid):
  """Adds tensors from potentially multiple devices."""
  # Basic function structure comes from control_flow_ops.group().
  # Sort tensors according to their devices.
  tensors_on_device = collections.defaultdict(lambda: [])
  for tensor in tensor_list:
    tensors_on_device[tensor.device].append(tensor)

  # For each device, add the tensors on that device first.
  # Then gather the partial sums from multiple devices.
  # TODO(sjhwang): Create hierarchical aggregation tree as pbar's suggestion.
  # E.g., aggregate per GPU, then per task, and so on.
  summands = []

  def DeviceKey(dev):
    return "" if dev is None else dev

  for dev in sorted(tensors_on_device, key=DeviceKey):
    tensors = tensors_on_device[dev]
    with ops._colocate_with_for_gradient(  # pylint: disable=protected-access
        tensors[0].op,
        gradient_uid,
        ignore_existing=True):
      summands.append(math_ops.add_n(tensors))

  return math_ops.add_n(summands)


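# Illustrative sketch (added for exposition; not part of TensorFlow):
# `_MultiDeviceAddN` first sums gradients per device, then sums the partial
# results. The gradient_uid below is a hypothetical placeholder string, and
# this helper is never called by the library.
def _example_multi_device_add_n():
  with ops.Graph().as_default():
    grads = [array_ops.ones([2]) for _ in range(3)]
    # All three tensors report the same (empty) device string here, so they
    # fall into a single group and are summed together on that device.
    return _MultiDeviceAddN(grads, gradient_uid="example_uid")

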
@tf_export("AggregationMethod")
class AggregationMethod(object):
  """A class listing aggregation methods used to combine gradients.

  Computing partial derivatives can require aggregating gradient
  contributions. This class lists the various methods that can
  be used to combine gradients in the graph.

  The following aggregation methods are part of the stable API for
  aggregating gradients:

  *  `ADD_N`: All of the gradient terms are summed as part of one
     operation using the "AddN" op (see `tf.add_n`). This
     method has the property that all gradients must be ready and
     buffered separately in memory before any aggregation is performed.
  *  `DEFAULT`: The system-chosen default aggregation method.

  The following aggregation methods are experimental and may not
  be supported in future releases:

  * `EXPERIMENTAL_TREE`: Gradient terms are summed in pairs using
    the "AddN" op. This method of summing gradients may reduce
    performance, but it can improve memory utilization because the
    gradients can be released earlier.

  """
  ADD_N = 0
  DEFAULT = ADD_N
  # The following are experimental and may not be supported in future releases.
  EXPERIMENTAL_TREE = 1
  EXPERIMENTAL_ACCUMULATE_N = 2  # Treated as an alias for EXPERIMENTAL_TREE.


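# Illustrative sketch (added for exposition; not part of TensorFlow): a
# hypothetical, never-called helper showing how an aggregation method is
# requested. Here x feeds three ops, so its incoming gradients are combined
# pairwise when EXPERIMENTAL_TREE is selected (see _AggregatedGrads below).
def _example_aggregation_method():
  with ops.Graph().as_default():
    x = array_ops.placeholder(dtypes.float32, shape=[], name="x")
    y = math_ops.add_n([math_ops.square(x), math_ops.square(x), x])
    (dy_dx,) = _GradientsHelper(
        [y], [x], aggregation_method=AggregationMethod.EXPERIMENTAL_TREE)
    return dy_dx

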
def _AggregatedGrads(grads,
                     op,
                     gradient_uid,
                     loop_state,
                     aggregation_method=None):
  """Get the aggregated gradients for op.

  Args:
    grads: The map of memoized gradients.
    op: The op to get gradients for.
    gradient_uid: A unique identifier within the graph indicating
      which invocation of gradients is being executed. Used to cluster
      ops for compilation.
    loop_state: An object for maintaining the state of the while loops in the
                graph. It is of type ControlFlowState. None if the graph
                contains no while loops.
    aggregation_method: Specifies the method used to combine gradient terms.
      Accepted values are constants defined in the class `AggregationMethod`.

  Returns:
    A list of gradients, one for each output of `op`. If the gradients
      for a particular output are a list, this function aggregates them
      before returning.

  Raises:
    TypeError: if the incoming grads are not Tensors or IndexedSlices.
    ValueError: if the arguments are invalid.

  """
  if aggregation_method is None:
    aggregation_method = AggregationMethod.DEFAULT
  if aggregation_method not in [
      AggregationMethod.ADD_N, AggregationMethod.EXPERIMENTAL_TREE,
      AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
  ]:
    raise ValueError(
        "Invalid aggregation_method specified %s." % aggregation_method)
  out_grads = _GetGrads(grads, op)
  for i, out_grad in enumerate(out_grads):
    if loop_state:
      if isinstance(out_grad, (ops.Tensor, ops.IndexedSlices)):
        assert control_flow_util.IsLoopSwitch(op)
        continue
    # Grads have to be Tensors or IndexedSlices.
    if (isinstance(out_grad, collections_abc.Sequence) and not all(
        isinstance(g, (ops.Tensor, ops.IndexedSlices))
        for g in out_grad
        if g is not None)):
      raise TypeError("gradients have to be either all Tensors "
                      "or all IndexedSlices")
    # Aggregate multiple gradients, and convert [] to None.
    if out_grad:
      if len(out_grad) < 2:
        used = "nop"
        out_grads[i] = out_grad[0]
      elif all(isinstance(g, ops.Tensor) for g in out_grad if g is not None):
        tensor_shape = _AccumulatorShape(out_grad)
        if aggregation_method in [
            AggregationMethod.EXPERIMENTAL_TREE,
            AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
        ]:
          # Aggregate all gradients by doing pairwise sums: this may
          # reduce performance, but it can improve memory because the
          # gradients can be released earlier.
          #
          # TODO(vrv): Consider replacing this with a version of
          # tf.AddN() that eagerly frees its inputs as soon as they are
          # ready, so the order of this tree does not become a problem.
          used = "tree"
          with ops.name_scope(op.name + "_gradient_sum"):
            running_sum = out_grad[0]
            for grad in out_grad[1:]:
              running_sum = math_ops.add_n([running_sum, grad])
            out_grads[i] = running_sum
        else:
          used = "add_n"
          out_grads[i] = _MultiDeviceAddN(out_grad, gradient_uid)
        logging.vlog(2, "  _AggregatedGrads %d x %s using %s", len(out_grad),
                     tensor_shape, used)
      else:
        out_grads[i] = backprop.aggregate_indexed_slices_gradients(out_grad)
    else:  # not out_grad
      # out_grads[i] is [], thus its aggregation is simply None.
      out_grads[i] = None
  return out_grads


# Represents the output of TFE_Py_TapeSetPossibleGradientTypes. Real enums are
# unfortunately too slow to use here.
POSSIBLE_GRADIENT_TYPES_NONE = 0
POSSIBLE_GRADIENT_TYPES_FIRST_ORDER = 1
POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER = 2


def PossibleTapeGradientTypes(tensors):
  """Determines whether and how `tensors` may require tape gradients."""
  return pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes(tensors)

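
# Illustrative sketch (added for exposition; not part of TensorFlow): with no
# tf.GradientTape active, no tape can produce a gradient for a tensor, so
# `PossibleTapeGradientTypes` is expected to report
# POSSIBLE_GRADIENT_TYPES_NONE. Hypothetical helper; never called by the
# library.
def _example_possible_tape_gradient_types():
  with context.eager_mode():
    t = ops.convert_to_tensor(1.0)
    return PossibleTapeGradientTypes([t]) == POSSIBLE_GRADIENT_TYPES_NONE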