1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""cond_v2 and gradient.
16
17This is a version of cond that emits a single If op, as well as the gradient
18function for If ops produced by cond_v2. This will eventually replace the
19current tf.cond implementation once it reaches feature and performance parity.
20"""
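# Editorial usage sketch (not part of this module): cond_v2 follows the
# tf.cond calling convention but always builds a single If/StatelessIf op.
# Assuming `x` and `y` are float Tensors created in the enclosing graph:
#
#   pred = math_ops.less(x, y)   # a boolean scalar Tensor, not a Python bool
#   out = cond_v2(pred, lambda: x + y, lambda: x * y)
#
# `out` is the output of the selected branch, packed like the branch outputs.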
21
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26import collections
27
28from tensorflow.core.framework import types_pb2
29from tensorflow.python.eager import backprop_util
30from tensorflow.python.framework import auto_control_deps
31from tensorflow.python.framework import auto_control_deps_utils as acd
32from tensorflow.python.framework import constant_op
33from tensorflow.python.framework import dtypes
34from tensorflow.python.framework import errors_impl
35from tensorflow.python.framework import func_graph as func_graph_module
36from tensorflow.python.framework import ops
37from tensorflow.python.framework import tensor_shape
38from tensorflow.python.framework import tensor_util
39from tensorflow.python.ops import array_ops
40from tensorflow.python.ops import control_flow_util
41from tensorflow.python.ops import control_flow_util_v2 as util
42from tensorflow.python.ops import default_gradient
43from tensorflow.python.ops import gen_dataset_ops
44from tensorflow.python.ops import gen_functional_ops
45from tensorflow.python.ops import gradients_util
46from tensorflow.python.ops import handle_data_util
47from tensorflow.python.ops import math_ops
48from tensorflow.python.util import nest
49
50
51# NOTE(skyewm): TensorFlow uses protected class methods and fields to signify
52# that they aren't part of the official public API. These protected members
53# often need to be used by implementation code however. Rather than litter the
54# code with pylint comments, we ignore protected access violations for
55# readability.
56# pylint: disable=protected-access
57
58_COND = 1
59_CASE = 2
60
61
62def cond_v2(pred, true_fn, false_fn, name="cond"):
63  """Like tf.cond, except emits a single If op."""
64  if isinstance(pred, bool):
65    raise TypeError("pred must not be a Python bool", pred)
66
67  if not name:
68    name = "cond"
69
70  with ops.name_scope(name) as scope:
71    true_name = util.unique_fn_name(scope, "true")
72    false_name = util.unique_fn_name(scope, "false")
73
74    # Automatic control dependencies are added in defuns, but not in v1
75    # graphs. Propagate that behavior here.
76    add_control_dependencies = ops.get_default_graph()._add_control_dependencies
77    pred = ops.convert_to_tensor(pred)
78    if (tensor_util.is_tf_type(pred) and
79        (pred.shape.dims is None or pred.shape.dims)):
80      pred = array_ops.squeeze_v2(pred)
81
82    true_graph = func_graph_module.func_graph_from_py_func(
83        true_name,
84        true_fn, [], {},
85        func_graph=util.CondBranchFuncGraph(
86            true_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
87        add_control_dependencies=add_control_dependencies,
88        op_return_value=pred)
89    false_graph = func_graph_module.func_graph_from_py_func(
90        false_name,
91        false_fn, [], {},
92        func_graph=util.CondBranchFuncGraph(
93            false_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
94        add_control_dependencies=add_control_dependencies,
95        op_return_value=pred)
96
97    verify_captures(_COND, [true_graph, false_graph])
98    return _build_cond(
99        pred,
100        true_graph,
101        false_graph,
102        true_graph.external_captures,
103        false_graph.external_captures,
104        building_gradient=False,
105        name=scope)
106
107
108@ops.RegisterGradient("StatelessIf")
109@ops.RegisterGradient("If")
110def _IfGrad(op, *grads):  # pylint: disable=invalid-name
111  """The gradient of an If op produced by cond_v2."""
112  # Get the If operator (this logic handles the case where op is a MockOp)
113  if_op = op.outputs[0].op
114  true_graph, false_graph = get_func_graphs(if_op)
115  # Note: op.graph != ops.get_default_graph() when we are computing the gradient
116  # of a nested cond.
117  assert true_graph.outer_graph == if_op.graph
118  assert false_graph.outer_graph == if_op.graph
119
120  # Create grad functions that compute the gradient of the true/false forward
121  # graphs. These functions will capture tensors from the forward pass
122  # functions.
123  true_grad_graph = _create_grad_func(
124      true_graph, grads, util.unique_grad_fn_name(true_graph.name))
125  false_grad_graph = _create_grad_func(
126      false_graph, grads, util.unique_grad_fn_name(false_graph.name))
127
128  # Replaces output None grads with zeros if at least one branch has non-None
129  # grad at that index.
130  _create_zeros_for_none_grads([true_graph, false_graph],
131                               [true_grad_graph, false_grad_graph])
132
133  if (true_grad_graph.op_needs_rewrite or false_grad_graph.op_needs_rewrite):
134    # Modify 'op' to output the intermediates needed by the grad functions. Note
135    # that all needed intermediates are wrapped in optionals. Each optional
136    # intermediate output will have a value iff its corresponding branch is
137    # taken.
138    # NOTE(skyewm): if there are any active sessions, this modification to `op`
139    # may make them unrunnable!
140
141    if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
142      # XLA does not yet support optionals, so output intermediates directly and
143      # make them match via FakeParams, which can be converted to zeros in XLA.
144      # TODO(skyewm,jpienaar): can XLA support optionals?
145      true_intermediates = true_grad_graph.xla_intermediates
146      false_intermediates = false_grad_graph.xla_intermediates
147      extra_true_outputs, extra_false_outputs = _make_intermediates_match_xla(
148          [true_graph, false_graph], [true_intermediates, false_intermediates])
149    else:
150      true_intermediates = true_grad_graph.wrapped_intermediates
151      false_intermediates = false_grad_graph.wrapped_intermediates
152      # Make outputs match by adding none optionals.
153      extra_true_outputs, extra_false_outputs = _make_intermediates_match(
154          [true_graph, false_graph], [true_intermediates, false_intermediates])
155
156    true_graph.outputs.extend(extra_true_outputs)
157    false_graph.outputs.extend(extra_false_outputs)
158    # TODO(skyewm): indicate it's an internal bug if this fails.
159    _check_same_outputs(_COND, [true_graph, false_graph])
160
161    true_graph.name += "_rewritten"
162    false_graph.name += "_rewritten"
163
164    if_op._set_func_attr("then_branch", util.create_new_tf_function(true_graph))
165    if_op._set_func_attr("else_branch",
166                         util.create_new_tf_function(false_graph))
167    if_op._set_type_list_attr("Tout", true_graph.output_types)
168    if_op._set_shape_list_attr("output_shapes", true_graph.output_shapes)
169    if_op._add_outputs(
170        [t.dtype for t in extra_true_outputs],
171        [t.shape for t in extra_true_outputs])
172
173  # Resolve references to forward graph tensors in grad graphs and ensure
174  # they are in-scope, i.e., belong to one of outer graphs of the grad graph.
175  true_grad_inputs = _resolve_grad_inputs(true_graph, true_grad_graph)
176  false_grad_inputs = _resolve_grad_inputs(false_graph, false_grad_graph)
177
178  # This modifies true_grad_graph and false_grad_graph.
179  _make_output_composite_tensors_match(_COND,
180                                       [true_grad_graph, false_grad_graph])
181
182  outputs = _build_cond(
183      if_op.inputs[0],
184      true_grad_graph,
185      false_grad_graph,
186      true_grad_inputs,
187      false_grad_inputs,
188      building_gradient=True,
189  )
190
191  # The predicate has no gradient.
192  return [None] + outputs
193
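# Editorial sketch of the optional-wrapping scheme used by the gradient code
# above (hypothetical standalone snippet, using ops this module already
# imports): a forward intermediate `t` is exported as a variant-dtype optional
# and unwrapped again inside the gradient graph,
#
#   opt = gen_dataset_ops.optional_from_value([t])            # taken branch
#   t_again = gen_dataset_ops.optional_get_value(
#       opt, [t.dtype], [t.shape])[0]                         # grad graph
#
# while the untaken branch emits gen_dataset_ops.optional_none() at the same
# output index so that both branch signatures still match.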
194
195def _build_cond(pred,
196                true_graph,
197                false_graph,
198                true_inputs,
199                false_inputs,
200                building_gradient,
201                name=None):
202  """Creates an If op from the specified predicate, branch functions and inputs.
203
204  Note that this modifies true_graph and false_graph to make the inputs match,
205  and to output all intermediate values so they're available for the gradient
206  computation.
207
208  true_graph and false_graph need not have the same input types, but they must
209  have the same output types.
210
211  Args:
212    pred: boolean Tensor
213    true_graph: FuncGraph
214    false_graph: FuncGraph
215    true_inputs: a list of Tensors to be passed to true_graph as input.
216    false_inputs: a list of Tensors to be passed to false_graph as input.
217    building_gradient: Whether this is a gradient If op.
218    name: the name for the If op.
219
220  Returns:
221    A list of Tensors which are the outputs of the If op. Does not include added
222    intermediate outputs.
223  """
224  _make_indexed_slices_indices_types_match(_COND, [true_graph, false_graph])
225  _check_same_outputs(_COND, [true_graph, false_graph])
226
227  # Add inputs to true_graph and false_graph to make them match. Note that
228  # this modifies true_graph and false_graph.
229  cond_inputs = _make_inputs_match([true_graph, false_graph],
230                                   [true_inputs, false_inputs])
231  # We do not output intermediates of the gradient If op since this is just
232  # for backwards compatibility with existing code.
233  if not building_gradient and util.output_all_intermediates():
234    # Add all intermediate tensors as function outputs so they're available for
235    # the gradient computation. Since the outputs of the two functions must
236    # match, we wrap all the intermediates in optionals. Each intermediate
237    # output will have a value iff its corresponding branch is taken.
238
239    true_intermediates = _get_intermediates(true_graph)
240    false_intermediates = _get_intermediates(false_graph)
241
242    # Wrap intermediates in optionals.
243    wrapped_true_intermediates = _wrap_intermediates(true_graph,
244                                                     true_intermediates)
245    wrapped_false_intermediates = _wrap_intermediates(false_graph,
246                                                      false_intermediates)
247
248    # Make outputs match by adding none optionals.
249    extra_true_outputs, extra_false_outputs = _make_intermediates_match(  # pylint: disable=unbalanced-tuple-unpacking
250        [true_graph, false_graph],
251        [wrapped_true_intermediates, wrapped_false_intermediates])
252
253    true_graph.outputs.extend(extra_true_outputs)
254    false_graph.outputs.extend(extra_false_outputs)
255    _check_same_outputs(_COND, [true_graph, false_graph])
256
257  # Create the If op.
258  with ops.control_dependencies(
259      list(true_graph.control_captures) + list(false_graph.control_captures)):
260    true_stateful_ops = [
261        op for op in true_graph.get_operations() if op._is_stateful
262    ]
263    false_stateful_ops = [
264        op for op in false_graph.get_operations() if op._is_stateful
265    ]
266    if (true_stateful_ops or false_stateful_ops):
267      op_fn = gen_functional_ops._if
268    else:
269      op_fn = gen_functional_ops.stateless_if
270
271    def _make_op(inputs):
272      if_op, tensors = util.get_op_and_outputs(op_fn(
273          pred,
274          inputs, [t.dtype for t in true_graph.outputs],
275          util.create_new_tf_function(true_graph),
276          util.create_new_tf_function(false_graph),
277          output_shapes=_get_output_shapes(true_graph.outputs,
278                                           false_graph.outputs),
279          name=name))
280      _copy_handle_data(tensors, true_graph.outputs, false_graph.outputs)
281      # `if_op` is None if this is a `StatelessIf` op with no outputs.
282      if if_op is not None:
283        # The true and false graphs have already been created, and we need that
284        # to happen before we know which tensors will be captured and so whether
285        # to wrap the cond in a tf.function. Post-hoc mutation of the branch
286        # `outer_graph` properties seems like the only option if we want to
287        # conditionally wrap in a function.
288        true_graph.outer_graph = ops.get_default_graph()
289        false_graph.outer_graph = ops.get_default_graph()
290        if_op._true_graph = true_graph
291        if_op._false_graph = false_graph
292        util.maybe_set_lowering_attr(if_op)
293        util.maybe_propagate_compile_time_consts_in_xla(if_op)
294        _set_read_only_resource_inputs_attr(if_op, [true_graph, false_graph])
295        # Prevent fetching since the variant outputs can't be fetched directly.
296        if_op.graph.prevent_fetching(if_op)
297      return tensors
298    tensors = util.run_as_function_for_tape_gradients(_make_op, cond_inputs)
299
300  # Return identities for each output of the If op, rather than the output of
301  # the If op directly. This makes pruning work if the output of cond() is
302  # fetched: the lowering pass converts the If outputs into IdentityN outputs,
303  # which if fetched will cause all ops in the taken branch to be run (since
304  # it takes all merge ops as input). After lowering, each output identity op
305  # will end up with only the appropriate merge op as input.
306  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
307  # correct output structure
308  tensors = [array_ops.identity(t) for t in tensors]
309
310  return _pack_sequence_as(true_graph.structured_outputs, tensors)
311
312
313def get_func_graphs(op):
314  """Returns `FuncGraph`s for the input op branches.
315
316  Args:
317    op: The If or Case Operation.
318
319  Returns:
320    A tuple of the `FuncGraph`s of the then_branch and else_branch (all branches
321    for Case).
322  """
323
324  def _get_func_graph_for_branch(name_attr_list, cached_attr_name=None):
325    """Generates and returns a FuncGraph for the given branch."""
326    func_graph = None
327    if cached_attr_name is not None:
328      func_graph = getattr(op, cached_attr_name, None)
329    inputs = op.inputs[1:]  # First input is pred.
330    if func_graph is None:
331      input_shapes = [t.shape for t in inputs]
332      func_graph = util.get_func_graph(op, input_shapes, name_attr_list.name)
333    for external_t, internal_t in zip(inputs, func_graph.inputs):
334      handle_data_util.copy_handle_data(external_t, internal_t)
335    func_graph.reset_captures(zip(inputs, func_graph.inputs))
336    # Link the op so that the gradient code can use it.
337    func_graph._forward_cond = op
338    return func_graph
339
340  if op.type in ["If", "StatelessIf"]:
341    return (_get_func_graph_for_branch(
342        op.get_attr("then_branch"), "_true_graph"),
343            _get_func_graph_for_branch(
344                op.get_attr("else_branch"), "_false_graph"))
345  elif op.type in ["Case", "StatelessCase"]:
346    return [_get_func_graph_for_branch(branch_fn, "_branch_graph_{}".format(i))
347            for i, branch_fn in enumerate(op.get_attr("branches"))]
348  else:
349    raise ValueError("Unsupported op type: {}".format(op.type))
350
351
352def _grad_fn(func_graph, grads):
353  """The gradient function for each conditional branch.
354
355  This function builds the gradient graph of the corresponding forward-pass
356  conditional branch in `func_graph`. This is done by differentiating
357  func_graph's outputs w.r.t. its inputs.
358
359  Args:
360    func_graph: FuncGraph. The corresponding forward-pass function.
361    grads: The list of input gradient Tensors.
362
363  Returns:
364    The output gradient Tensors.
365  """
366  # Filter out untrainable function outputs.
367  # NOTE(skyewm): If we don't do this, the untrainable tensors can sometimes
368  # cause _GradientsHelper to raise an exception (e.g. the implementation
369  # doesn't expect 'ys' to contain boolean tensors).
370  assert len(func_graph.outputs) == len(grads)
371  ys = []
372  grad_ys = []
373  for y, grad_y in zip(func_graph.outputs, grads):
374    if not backprop_util.IsTrainable(y):
375      continue
376    ys.append(y)
377    grad_ys.append(grad_y)
378
379  # Build the gradient graph. Note that this builds the gradient computation of
380  # func_graph in the current graph, which requires capturing tensors from
381  # func_graph. The captured func_graph tensors are resolved to external tensors
382  # in _resolve_grad_inputs.
383  result = gradients_util._GradientsHelper(
384      ys, func_graph.inputs, grad_ys=grad_ys,
385      src_graph=func_graph)
386
387  return result
388
389
390def _create_grad_func(func_graph, grads, name):
391  """Returns the FuncGraph representation of _grad_fn."""
392  return func_graph_module.func_graph_from_py_func(
393      name,
394      lambda: _grad_fn(func_graph, grads), [], {},
395      func_graph=_CondGradFuncGraph(name, func_graph))
396
397
398def _resolve_grad_inputs(cond_graph, grad_graph):
399  """Returns the tensors to pass as inputs to `grad_graph`.
400
401  The `grad_graph` may have external references to
402  1. Its outer graph containing the input gradients. These references are kept
403     as is.
404  2. Tensors in the forward pass graph. These tensors may not be "live"
405     when the gradient is being computed. We replace such references by their
406     corresponding tensor in `cond_graph.outer_graph`. In the case of nested
407     control flow or functions, the gradient logic handling
408     `grad_graph.outer_graph` will make sure the tensor from
409     `cond_graph.outer_graph` is also correctly captured.
410
411  Args:
412    cond_graph: FuncGraph. The forward-pass function.
413    grad_graph: FuncGraph. The gradients function.
414
415  Returns:
416    A list of input tensors to be passed to grad_graph.
417  """
418  new_inputs = []
419
420  for t in grad_graph.external_captures:
421    # `t` must either be in `grad_graph.outer_graph` or in the forward
422    # `cond_graph`.
423    if t.graph != grad_graph.outer_graph:
424      assert t.graph == cond_graph
425      # `internal_captures` are not treated as intermediates and hence not added
426      # to If op outputs. So we get the outer tensor corresponding to those
427      # from the list of `external_captures`.
428      for i, output in enumerate(t.graph.outputs):
429        if output is t:
430          t = t.graph._forward_cond.outputs[i]
431          break
432      else:
433        for i, output in enumerate(t.graph.internal_captures):
434          if output is t:
435            t = t.graph.external_captures[i]
436            break
437        else:
438          raise ValueError("Could not find external tensor capture {tensor} in "
439                           "captures or outputs".format(tensor=t))
440
441      # Note: We rely on the capturing logic of the gradient If op graph to
442      # correctly capture the tensors in `cond_graph.outer_graph`. Both cond_v2
443      # and while_v2 handle this while building their gradient functions.
444      assert t.graph == cond_graph.outer_graph
445    new_inputs.append(t)
446
447  return new_inputs
448
449
450def _get_intermediates(func_graph):
451  """Returns intermediate tensors of `func_graph` for gradient computation."""
452  intermediates = []
453  for op in func_graph.get_operations():
454    for t in op.outputs:
455      if t in func_graph.inputs: continue
456      if t in func_graph.outputs: continue
457      if t.dtype is dtypes.resource:
458        continue
459      # Accumulating mutexes can cause deadlock.
460      if op.type == "MutexLock":
461        continue
462      intermediates.append(t)
463  return intermediates
464
465
466def _make_intermediates_match(branch_graphs, branch_optionals):
467  """Returns new optionals lists that have matching signatures.
468
469  This is done by mirroring each list in the other using none optionals.
470  There is no merging of like optionals.
471
472  Args:
473    branch_graphs: `list` of `FuncGraph`.
474    branch_optionals: `list` of `list`s of optional `Tensor`s, one list per
475      graph in `branch_graphs`.
476
477  Returns:
478    A `list` of `list`s of `Tensor`s for each branch_graph. Each list has the
479    same number of `Tensor`s, all of which will be optionals of the same
480    shape/type.
481  """
482  new_branch_optionals = []
483  # Since the intermediates are optionals with dtype variant, we only need
484  # enough room for the longest list of intermediates.
485  intermediates_size = max(len(o) for o in branch_optionals)
486  for i, branch_graph in enumerate(branch_graphs):
487    other_optionals = _create_none_optionals(
488        branch_graph, intermediates_size - len(branch_optionals[i]))
489    new_branch_optionals.append(branch_optionals[i] + other_optionals)
490  return new_branch_optionals
491
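# Worked example for _make_intermediates_match (editorial, hypothetical
# values): with branch_optionals = [[o1, o2], [o3]], the shorter list is
# padded with a none optional created in its own branch graph, giving
# [[o1, o2], [o3, none]]; both branches then output the same number of
# variants.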
492
493def _make_intermediates_match_xla(branch_graphs, branch_intermediates):
494  """Like _make_intermediates_match but for the XLA case."""
495  new_branch_intermediates = []
496  for i, branch_graph in enumerate(branch_graphs):
497    other_fakeparams = _create_fakeparams(
498        branch_graph,
499        sum((bi for bi in branch_intermediates
500             if bi is not branch_intermediates[i]), []))
501    num_preceding = sum(len(bi) for bi in branch_intermediates[:i])
502    new_branch_intermediates.append(other_fakeparams[:num_preceding] +
503                                    branch_intermediates[i] +
504                                    other_fakeparams[num_preceding:])
505  return new_branch_intermediates
506
507
508def _make_inputs_match(branch_graphs, branch_inputs):
509  """Modifies branch_graphs so they have the same input signature.
510
511  This method reorders and/or adds parameters to each graph in branch_graphs so
512  they have the same input signature, and updates the 'inputs' and 'captured'
513  fields of each graph accordingly. It uses the input tensors from the outer
514  graph to avoid duplicating shared arguments.
515
516  Args:
517    branch_graphs: a `list` of `FuncGraph`
518    branch_inputs: a `list` of `list`s of `Tensor`s in the outer graph. The
519      inputs for the corresponding graph in `branch_graphs`.
520
521  Returns:
522    A new list of Tensors from the outer graph that are the new inputs for each
523    branch_graph. This is a deduped version of `sum(branch_inputs)`.
524  """
525  assert len(branch_graphs) == len(branch_inputs)
526  added_inputs = set()
527  new_inputs = []
528  for branch_in in branch_inputs:
529    for tensor in branch_in:
530      tensor_id = ops.tensor_id(tensor)
531      if tensor_id not in added_inputs:
532        added_inputs.add(tensor_id)
533        new_inputs.append(tensor)
534
535  for branch_graph, branch_in in zip(branch_graphs, branch_inputs):
536    input_ids = [ops.tensor_id(t) for t in branch_in]
537    branch_input_to_param = dict(zip(input_ids, branch_graph.inputs))
538    input_list = []
539    for in_t in new_inputs:
540      param = branch_input_to_param.get(ops.tensor_id(in_t))
541      if param is None:
542        param = _create_dummy_input(branch_graph, in_t)
543      input_list.append(param)
544
545    branch_graph.inputs = input_list
546
547    # Rewrite the FuncGraphs' state to reflect the new inputs.
548    branch_graph.reset_captures(zip(new_inputs, branch_graph.inputs))
549
550  return new_inputs
551
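# Worked example for _make_inputs_match (editorial, hypothetical values):
# with branch_inputs = [[a, b], [b, c]] the deduped outer inputs become
# [a, b, c]; each branch keeps its existing parameters for the tensors it
# already captured and gets a dummy placeholder for the ones it did not
# (branch 0 for c, branch 1 for a), so both graphs accept the same signature.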
552
553def _create_zeros_for_none_grads(forward_graphs, grad_graphs):
554  """Creates zeros for None output grads if any branch has a non-None grad.
555
556  Args:
557    forward_graphs: List of forward FuncGraphs.
558    grad_graphs: List of grad FuncGraphs.
559  """
560  assert len(forward_graphs) == len(grad_graphs)
561  branch_outputs = [g.structured_outputs for g in grad_graphs]
562  num_outputs_per_branch = [len(outs) for outs in branch_outputs]
563  assert len(set(num_outputs_per_branch)) == 1, num_outputs_per_branch
564  for output_idx, branch_outs in enumerate(zip(*branch_outputs)):
565    if (any(t is None for t in branch_outs) and
566        any(t is not None for t in branch_outs)):
567      for branch_index, t in enumerate(branch_outs):
568        if t is None:
569          with grad_graphs[branch_index].as_default():
570            zeros = default_gradient.zeros_like(
571                forward_graphs[branch_index].inputs[output_idx])
572            grad_graphs[branch_index].structured_outputs[output_idx] = zeros
573
574  for grad_graph in grad_graphs:
575    grad_graph.outputs = [
576        t for t in func_graph_module.flatten(grad_graph.structured_outputs)
577        if t is not None
578    ]
579
580
581def _make_output_composite_tensors_match(op_type, branch_graphs):
582  """Modifies each branch_graph's outputs to have the same output signature.
583
584  Currently the only transformation implemented is turning a Tensor into an
585  equivalent IndexedSlices if another branch returns an IndexedSlices.
586  Updates branch_graph.{outputs,structured_outputs} for each branch_graph in
587  branch_graphs.
588
589  Args:
590    op_type: _COND or _CASE
591    branch_graphs: `list` of `FuncGraph`
592
593  Raises:
594    TypeError: if a set of outputs cannot be rewritten.
595  """
596  # Note: since this is only used for gradient graphs, we do not expect the
597  # outputs to be structured (e.g. nested lists), and thus do not need to use
598  # nest.flatten, etc.
599  assert branch_graphs
600  branch_outputs = [g.structured_outputs for g in branch_graphs]
601  outputs_per_branch = list(len(outs) for outs in branch_outputs)
602  assert len(set(outputs_per_branch)) == 1, outputs_per_branch
603
604  for output_idx, branch_outs in enumerate(zip(*branch_outputs)):
605    if len(set(type(out) for out in branch_outs)) == 1:
606      continue
607    if not any(isinstance(out, ops.IndexedSlices) for out in branch_outs):
608      continue
609    for branch_idx, branch_out in enumerate(branch_outs):
610      if isinstance(branch_out, ops.IndexedSlices):
611        continue
612      elif isinstance(branch_out, ops.Tensor):
613        with branch_graphs[branch_idx].as_default():
614          branch_outputs[branch_idx][output_idx] = math_ops._as_indexed_slices(
615              branch_out)
616      else:
617        raise TypeError(
618            "Cannot reconcile {op_name} {output_idx}-th outputs:\n"
619            "  outputs from all branches: {outputs}".format(
620                op_name="tf.cond" if op_type == _COND else "tf.switch_case",
621                output_idx=output_idx,
622                outputs=branch_outs))
623
624  for branch_graph, branch_outs in zip(branch_graphs, branch_outputs):
625    branch_graph.structured_outputs = branch_outs
626    branch_graph.outputs = [
627        t for t in func_graph_module.flatten(branch_outs) if t is not None
628    ]
629
630
631def _make_indexed_slices_indices_types_match(op_type, branch_graphs):
632  """Match dtype of IndexedSlices.indices in outputs of branch_graphs."""
633  assert branch_graphs
634  # Indices of `IndexedSlices.indices` tensors in `branch_graphs[i].outputs`.
635  indexed_slice_indices = []
636  current_index = 0
637  # Note that this still contains Nones. We leave those in so that error
638  # messages contain the correct indices. We handle the Nones later when
639  # updating `current_index`.
640  branch_outputs_flat_with_composites = [
641      nest.flatten(branch_graph.structured_outputs, expand_composites=False)
642      for branch_graph in branch_graphs
643  ]
644  outs_per_branch = [len(outs) for outs in branch_outputs_flat_with_composites]
645  assert len(set(outs_per_branch)) == 1, outs_per_branch
646  # Store indices of IndexedSlices.indices in `indexed_slice_indices`.
647  for output_idx, branch_outs in enumerate(
648      zip(*branch_outputs_flat_with_composites)):
649    if len(set(isinstance(out, ops.IndexedSlices) for out in branch_outs)) != 1:
650      raise TypeError("Cannot reconcile tf.{op_name} {output_idx}-th outputs:\n"
651                      "  branches returned: {outputs}".format(
652                          op_name="cond" if op_type == _COND else "switch_case",
653                          output_idx=output_idx,
654                          outputs=branch_outs))
655    if isinstance(branch_outs[0], ops.IndexedSlices):
656      # indices is the second component of the composite tensor.
657      indexed_slice_indices.append(current_index + 1)
658    if nest.is_sequence_or_composite(branch_outs[0]):
659      current_index += len(nest.flatten(branch_outs[0], expand_composites=True))
660    elif branch_outs[0] is not None:
661      # `FuncGraph.outputs` does not contain Nones so no need to update the
662      # counter in that case.
663      current_index += 1
664
665  if not indexed_slice_indices:
666    return
667
668  # `FuncGraph.outputs` is the flattened `FuncGraph.structured_outputs` minus
669  # the Nones.
670  if current_index != len(branch_graphs[0].outputs):
671    raise ValueError("Insufficient elements in branch_graphs[0].outputs.\n"
672                     "Expected: %i\n"
673                     "Actual: %i" %
674                     (current_index, len(branch_graphs[0].outputs)))
675
676  # Cast indices with mismatching types to int64.
677  for index in indexed_slice_indices:
678    if any(bg.outputs[index].dtype not in (dtypes.int32, dtypes.int64)
679           for bg in branch_graphs):
680      raise TypeError("Type of IndexedSlices.indices must be int32 or int64. "
681                      "Found: %s" %
682                      str([bg.outputs[index].dtype for bg in branch_graphs]))
683    if len(set(bg.outputs[index].dtype for bg in branch_graphs)) != 1:
684      for branch_graph in branch_graphs:
685        if branch_graph.outputs[index].dtype == dtypes.int32:
686          with branch_graph.as_default():
687            branch_graph.outputs[index] = math_ops.cast(
688                branch_graph.outputs[index], dtypes.int64)
689
690  for branch_graph in branch_graphs:
691    branch_graph.structured_outputs = _pack_sequence_as(
692        branch_graph.structured_outputs, branch_graph.outputs)
693
694
695def _pack_sequence_as(structured_outputs, op_outputs):
696  """Packs the outputs of the gradient If/Case op.
697
698  The branch functions may contain None's in the list of `structured_outputs`.
699  `op_outputs` has those outputs missing. So we need to add those Nones to the
700  list of `op_outputs` and then pack it in the same structure as
701  `structured_outputs`.
702
703  Args:
704    structured_outputs: structured_outputs from one of the branch functions.
705    op_outputs: List of output tensors of the op.
706
707  Returns:
708    `op_outputs` packed like `structured_outputs`.
709  """
710  outputs_with_nones = []
711  counter = 0
712  for output in nest.flatten(structured_outputs, expand_composites=True):
713    if output is None:
714      outputs_with_nones.append(None)
715    else:
716      outputs_with_nones.append(op_outputs[counter])
717      counter += 1
718  return func_graph_module.pack_sequence_as(structured_outputs,
719                                            outputs_with_nones)
720
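# Worked example for _pack_sequence_as (editorial, hypothetical values): if a
# branch's structured_outputs is (None, {'a': t0}, t1) and the op produced
# [o0, o1], the packed result is (None, {'a': o0}, o1); None entries never
# consume an op output.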
721
722def _wrap_intermediates(func_graph, intermediates):
723  with func_graph.as_default():
724    return [gen_dataset_ops.optional_from_value([t]) for t in intermediates]
725
726
727def _create_dummy_input(func_graph, template_tensor):
728  """Creates a tensor in func_graph to represent template_tensor.
729
730  Args:
731    func_graph: FuncGraph.
732    template_tensor: a tensor in the outer graph.
733
734  Returns:
735    A tensor in func_graph.
736  """
737  with func_graph.as_default():
738    return array_ops.placeholder(
739        template_tensor.dtype, shape=template_tensor.shape)
740
741
742def _create_none_optionals(func_graph, n):
743  """Creates `n` `None` optionals in func_graph.
744
745  Args:
746    func_graph: FuncGraph.
747    n: `int` the number of `None` optionals to make.
748
749  Returns:
750    A list of tensors in func_graph.
751  """
752  with func_graph.as_default():
753    return [gen_dataset_ops.optional_none() for _ in range(n)]
754
755
756def _create_fakeparams(func_graph, template_tensors):
757  """Create FakeParams for the XLA case."""
758  with func_graph.as_default():
759    return [gen_functional_ops.fake_param(dtype=t.dtype, shape=t.shape)
760            for t in template_tensors]
761
762
763def _check_same_outputs(op_type, graphs):
764  """Raises an error if `graphs` have different outputs."""
765
766  def error(branch_idx, error_detail):
767    raise TypeError(
768        "{b0_name} and {bn_name} arguments to {op_name} must have the same "
769        "number, type, and overall structure of return values.\n"
770        "\n"
771        "{b0_name} output: {b0_out}\n"
772        "{bn_name} output: {bn_out}\n"
773        "\n"
774        "Error details:\n"
775        "{detail}".format(
776            b0_name="true_fn" if op_type == _COND else "branches[0]",
777            bn_name=("false_fn" if op_type == _COND else
778                     "branches[{}]".format(branch_idx)),
779            op_name="tf.cond" if op_type == _COND else "tf.switch_case",
780            b0_out=graphs[0].structured_outputs,
781            bn_out=graphs[branch_idx].structured_outputs,
782            detail=error_detail))
783
784  for b in range(1, len(graphs)):
785    try:
786      nest.assert_same_structure(
787          graphs[0].structured_outputs,
788          graphs[b].structured_outputs,
789          expand_composites=True)
790    except (ValueError, TypeError) as e:
791      error(b, str(e))
792
793    op_type_str = "cond" if op_type == _COND else "case"
794    if len(graphs[0].outputs) != len(graphs[b].outputs):
795      raise ValueError("Lengths of branch outputs of {op_type} must match.\n"
796                       "len(graphs[0].outputs): {len_0}\n"
797                       "len(graphs[{b}].outputs): {len_b}\n".format(
798                           op_type=op_type_str,
799                           len_0=len(graphs[0].outputs),
800                           b=b,
801                           len_b=len(graphs[b].outputs)))
802    for b0_out, bn_out in zip(graphs[0].outputs, graphs[b].outputs):
803      if b0_out.dtype != bn_out.dtype:
804        error(b, "%s and %s have different types" % (b0_out, bn_out))
805
806
807def _get_output_shapes(*branch_graph_outputs):
808  output_shapes = []
809  for out_by_branch in zip(*branch_graph_outputs):
810    shape = out_by_branch[0].shape
811    for other_out in out_by_branch[1:]:
812      shape = shape.most_specific_compatible_shape(other_out.shape)
813    output_shapes.append(shape)
814  return output_shapes
815
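# Editorial note on _get_output_shapes: most_specific_compatible_shape keeps a
# dimension only if every branch agrees on it. For example, shapes [None, 2]
# and [3, 2] combine to [None, 2], while [3] and [3, 2] (different ranks)
# combine to a fully unknown shape.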
816
817def _copy_handle_data(external_tensors, *branch_graph_outputs):
818  """Combines shapes in handle data and sets metadata on `external_tensors`."""
819  for tensors in zip(external_tensors, *branch_graph_outputs):
820    external = tensors[0]
821    internal = tensors[1:]
822    internal_handle_data = []
823    for tensor in internal:
824      handle_data = handle_data_util.get_resource_handle_data(tensor)
825      # NOTE: Assumes handle data has only one ShapeAndType entry. It's
826      # unclear how to combine different lengths across branches.
827      if not handle_data.is_set or len(handle_data.shape_and_type) != 1:
828        break
829      internal_handle_data.append(handle_data)
830    else:  # There is handle data, so we need to combine it.
831      combined_shape = tensor_shape.TensorShape(None)
832      combined_dtype = None
833      for handle_data in internal_handle_data:
834        handle_shape = tensor_shape.TensorShape(
835            handle_data.shape_and_type[0].shape)
836        combined_shape = combined_shape.most_specific_compatible_shape(
837            handle_shape)
838        if combined_dtype is None:
839          combined_dtype = handle_data.shape_and_type[0].dtype
840        elif handle_data.shape_and_type[0].dtype != combined_dtype:
841          # Variants from different branches have different dtypes. The
842          # combined variant has no static dtype.
843          combined_dtype = types_pb2.DT_INVALID
844      combined_handle_data = internal_handle_data[0]
845      combined_handle_data.shape_and_type[0].shape.CopyFrom(
846          combined_shape.as_proto())
847      combined_handle_data.shape_and_type[0].dtype = combined_dtype
848      handle_data_util.set_handle_data(external, combined_handle_data)
849
850
851def verify_captures(op_type, branch_graphs):
852  """Verify that a branch's tensor is not accessed in another branch fn."""
853  # Note: It is technically not possible for lower-branch_index branches to
854  # capture tensors from higher-branch_index branches, because of the order of
855  # branch graph construction, but we check all for completeness and to
856  # guard against potential future changes.
857  other_branch_graphs = {g: i for i, g in enumerate(branch_graphs)}
858  for i, branch_graph in enumerate(branch_graphs):
859    for t in branch_graph.external_captures:
860      if not isinstance(t, ops.EagerTensor) and t.graph in other_branch_graphs:
861        branch_names = ["true_fn", "false_fn"] if op_type == _COND else [
862            "branch {}".format(bi) for bi in range(len(branch_graphs))]
863        raise ValueError(
864            "Tensor {tname} in {b0name} is accessed from {b1name}.".format(
865                tname=t.name,
866                b0name=branch_names[other_branch_graphs[t.graph]],
867                b1name=branch_names[i]))
868
869
870class _CondGradFuncGraph(util.CondBranchFuncGraph):
871  """FuncGraph for the gradient function of the branch of an If op.
872
873  Handles wrapping and unwrapping intermediate values that are captured by the
874  gradient computation in optionals.
875
876  Attributes:
877    op_needs_rewrite: True if any intermediates were captured, meaning the
878      forward If op needs to be rewritten to output the wrapped intermediates.
879  """
880
881  def __init__(self, name, forward_graph):
882    super(_CondGradFuncGraph, self).__init__(
883        name, collections=ops.get_default_graph()._collections)  # pylint: disable=protected-access
884    self.op_needs_rewrite = False
885    self._forward_graph = forward_graph
886    # Maps from forward intermediate tensor -> the unwrapped captured
887    # intermediate.
888    self._indirect_captures = {}
889    # Maps unwrapped intermediate -> optional-wrapped intermediate in the
890    # forward graph.
891    self._wrapped_intermediates = collections.OrderedDict()
892    # Raw intermediates captured from the forward graph. Populated iff we're in
893    # an XLA context.
894    self._xla_intermediates = []
895    # Maps forward intermediate constant valued tensor's id to the constant
896    # created in this graph for that tensor.
897    self._captured_constants = {}
898
899  @property
900  def wrapped_intermediates(self):
901    """The optional-wrapped intermediates captured from the forward graph."""
902    return list(self._wrapped_intermediates.values())
903
904  @property
905  def xla_intermediates(self):
906    """Raw intermediates captured from the forward graph if XLA is enabled."""
907    return self._xla_intermediates
908
909  def _capture_helper(self, tensor, name):
910    if (tensor.graph is not self._forward_graph or
911        any(tensor is t for t in self._forward_graph.inputs) or
912        any(tensor is t for t in self._forward_graph.outputs)):
913      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)
914
915    tensor_id = ops.tensor_id(tensor)
916
917    # If `tensor` is a graph-building time constant, we create a constant with
918    # the same value in the backward graph instead of capturing it.
919    if tensor_id in self._captured_constants:
920      return self._captured_constants[tensor_id]
921    elif constant_op.is_constant(tensor):
922      self._captured_constants[tensor_id] = constant_op.constant(
923          tensor_util.constant_value(tensor), dtype=tensor.dtype)
924      return self._captured_constants[tensor_id]
925
926    if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
927      # XLA does not yet support optionals, so capture intermediates directly.
928      # TODO(skyewm,jpienaar): can XLA support optionals?
929      if all(tensor is not capture for capture in self.external_captures):
930        self.xla_intermediates.append(tensor)
931        self.op_needs_rewrite = True
932      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)
933
934    captured_tensor = self._indirect_captures.get(tensor_id)
935    if captured_tensor is not None:
936      return captured_tensor
937
938    # 'tensor' is an uncaptured intermediate in the forward graph.
939    # If it is not a resource, we wrap it in an optional in the forward graph
940    # and capture the optional normally. We then unwrap the captured optional
941    # value in the gradient graph to get the raw intermediate value.
942    # If it is a resource, we trace the resource up to the input in the forward
943    # graph and capture that.
944
945    if tensor.dtype == dtypes.resource:
946      # Index of the forward graph input corresponding to the resource tensor.
947      index = util.resource_input_index(
948          tensor.name, [t.name for t in self._forward_graph.inputs],
949          {op.name: op.node_def for op in self._forward_graph.get_operations()},
950          self._forward_graph._functions)
951      # This gets mapped to the corresponding If op input in
952      # `_resolve_grad_inputs`.
953      captured_tensor = super(_CondGradFuncGraph, self)._capture_helper(
954          self._forward_graph.inputs[index], name)
955    else:
956      if tensor_id not in self._wrapped_intermediates:
957        # If the gradient has already been computed for this If op, 'tensor' may
958        # already be wrapped.
959        for consumer in tensor.consumers():
960          if (consumer.type == "OptionalFromValue" and
961              any(consumer.outputs[0] is output
962                  for output in self._forward_graph.outputs)):
963            optional = consumer.outputs[0]
964            break
965        else:
966          # 'tensor' hasn't been wrapped, do it now.
967          with self._forward_graph.as_default():
968            optional = gen_dataset_ops.optional_from_value([tensor])
969          self.op_needs_rewrite = True
970        self._wrapped_intermediates[tensor_id] = optional
971
972      optional = self._wrapped_intermediates[tensor_id]
973      captured_optional = super(_CondGradFuncGraph,
974                                self)._capture_helper(optional, name)
975      captured_tensor = gen_dataset_ops.optional_get_value(
976          captured_optional, [tensor.dtype], [tensor.shape])[0]
977
978    self._indirect_captures[tensor_id] = captured_tensor
979    return captured_tensor
980
981
982def indexed_case(branch_index,
983                 branch_fns,
984                 name="indexed_case",
985                 lower_using_switch_merge=None):
986  """Like cond_v2, except emits a Case op instead of an If."""
987  if isinstance(branch_index, int):
988    raise TypeError("branch_index must not be a Python int", branch_index)
989
990  with ops.name_scope(name) as scope:
991    branch_names = [
992        util.unique_fn_name(scope, "branch{}".format(b))
993        for b in range(len(branch_fns))
994    ]
995
996    # Automatic control dependencies are added in defuns, but not in v1
997    # graphs. Propagate that behavior here.
998    add_control_dependencies = ops.get_default_graph()._add_control_dependencies
999    branch_index = ops.convert_to_tensor(branch_index, name="branch_index")
1000
1001    branch_graphs = []
1002    for branch_name, branch_fn in zip(branch_names, branch_fns):
1003      branch_graphs.append(
1004          func_graph_module.func_graph_from_py_func(
1005              branch_name,
1006              branch_fn,
1007              [],
1008              {},
1009              func_graph=util.CondBranchFuncGraph(
1010                  branch_name,
1011                  collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
1012              add_control_dependencies=add_control_dependencies,
1013              op_return_value=branch_index))
1014
1015    verify_captures(_CASE, branch_graphs)
1016    return _build_case(
1017        branch_index,
1018        branch_graphs, [g.external_captures for g in branch_graphs],
1019        name=scope,
1020        lower_using_switch_merge=lower_using_switch_merge)
1021
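# Editorial usage sketch (not part of this module): indexed_case mirrors
# tf.switch_case. Assuming `x` is a float Tensor in the enclosing graph:
#
#   idx = constant_op.constant(1)   # int32 scalar Tensor, not a Python int
#   out = indexed_case(idx, [lambda: x + 1., lambda: x * 2., lambda: x - 1.])
#
# Here `out` evaluates to x * 2. since branch 1 is selected.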
1022
1023@ops.RegisterGradient("Case")
1024@ops.RegisterGradient("StatelessCase")
1025def _CaseGrad(op, *grads):  # pylint: disable=invalid-name
1026  """The gradient of a Case op produced by tf.switch_case."""
1027  # Get the Case operator (this logic handles the case where op is a MockOp)
1028  case_op = op.outputs[0].op
1029  branch_graphs = get_func_graphs(case_op)
1030  assert branch_graphs
1031  # Note: op.graph != ops.get_default_graph() when we are computing the gradient
1032  # of a nested cond.
1033  for branch_graph in branch_graphs:
1034    assert branch_graph.outer_graph == case_op.graph
1035
1036  # Create grad functions that compute the gradient of the branch forward
1037  # graphs. These functions will capture tensors from the forward pass
1038  # functions.
1039  branch_grad_graphs = []
1040  for branch_graph in branch_graphs:
1041    branch_grad_graphs.append(
1042        _create_grad_func(branch_graph, grads,
1043                          util.unique_grad_fn_name(branch_graph.name)))
1044  # Replaces output None grads with zeros if at least one branch has non-None
1045  # grad at that index.
1046  _create_zeros_for_none_grads(branch_graphs, branch_grad_graphs)
1047
1048  if any(g.op_needs_rewrite for g in branch_grad_graphs):
1049    # Modify 'op' to output the intermediates needed by the grad functions. Note
1050    # that all needed intermediates are wrapped in optionals. Each optional
1051    # intermediate output will have a value iff its corresponding branch is
1052    # taken.
1053    # NOTE(bjp): if there are any active sessions, this modification to `op`
1054    # may make them unrunnable!
1055
1056    if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
1057      # XLA does not yet support optionals, so output intermediates directly and
1058      # make them match via FakeParams, which can be converted to zeros in XLA.
1059      # TODO(bjp,jpienaar): can XLA support optionals?
1060      branches_intermediates = [
1061          branch_grad_graph.xla_intermediates
1062          for branch_grad_graph in branch_grad_graphs
1063      ]
1064      extra_branch_outputs = _make_intermediates_match_xla(
1065          branch_graphs, branches_intermediates)
1066    else:
1067      branch_intermediates = [
1068          g.wrapped_intermediates for g in branch_grad_graphs
1069      ]
1070      # Make outputs match by adding none optionals.
1071      extra_branch_outputs = _make_intermediates_match(branch_graphs,
1072                                                       branch_intermediates)
1073
1074    for branch_graph, extra_outputs in zip(branch_graphs, extra_branch_outputs):
1075      branch_graph.outputs.extend(extra_outputs)
1076    # TODO(bjp): indicate it's an internal bug if this fails.
1077    _check_same_outputs(_CASE, branch_graphs)
1078
1079    for branch_graph in branch_graphs:
1080      branch_graph.name += "_rewritten"
1081
1082    case_op._set_func_list_attr("branches", [
1083        util.create_new_tf_function(branch_graph)
1084        for branch_graph in branch_graphs
1085    ])
1086    case_op._set_type_list_attr("Tout", branch_graphs[0].output_types)
1087    case_op._set_shape_list_attr("output_shapes",
1088                                 branch_graphs[0].output_shapes)
1089    case_op._add_outputs([t.dtype for t in extra_branch_outputs[0]],
1090                         [t.shape for t in extra_branch_outputs[0]])
1091
1092  # Resolve references to forward graph tensors in grad graphs and ensure
1093  # they are in-scope, i.e., belong to one of outer graphs of the grad graph.
1094  branches_grad_inputs = [
1095      _resolve_grad_inputs(branch_graph, branch_grad_graph) for branch_graph,
1096      branch_grad_graph in zip(branch_graphs, branch_grad_graphs)
1097  ]
1098
1099  # This modifies the graphs in branch_grad_graphs.
1100  _make_output_composite_tensors_match(_CASE, branch_grad_graphs)
1101
1102  try:
1103    lowering = case_op._get_attr_bool("_lower_using_switch_merge")
1104  except errors_impl.NotFoundError:
1105    lowering = None
1106
1107  outputs = _build_case(
1108      case_op.inputs[0],
1109      branch_grad_graphs,
1110      branches_grad_inputs,
1111      name="gradient",
1112      lower_using_switch_merge=lowering)
1113
1114  # The predicate has no gradient.
1115  return [None] + outputs
1116
1117
1118def _build_case(branch_index,
1119                branch_graphs,
1120                branch_inputs,
1121                name=None,
1122                lower_using_switch_merge=None):
1123  """Creates a `Case` op from `branch_index`, branch graphs and inputs.
1124
1125  Note that this modifies `branch_graphs` to make the inputs match, and to
1126  output all intermediate values so they're available for the gradient
1127  computation.
1128
1129  `branch_graphs` need not have the same input types, but they must
1130  have the same output types.
1131
1132  Args:
1133    branch_index: integer Tensor
1134    branch_graphs: List of FuncGraph
1135    branch_inputs: List of lists of Tensors to be passed to corresponding
1136      branch_graph as input.
1137    name: the name for the Case op.
1138    lower_using_switch_merge: Lower this op using switch merge ops (optional).
1139
1140  Returns:
1141    A list of Tensors which are the outputs of the Case op. Does not include
1142    added intermediate outputs.
1143  """
1144  _make_indexed_slices_indices_types_match(_CASE, branch_graphs)
1145  _check_same_outputs(_CASE, branch_graphs)
1146
1147  # Add inputs to branch_graphs to make them match. Note that this modifies the
1148  # graphs in `branch_graphs`.
1149  case_inputs = _make_inputs_match(branch_graphs, branch_inputs)
1150
1151  stateful_ops = []
1152  for bg in branch_graphs:
1153    stateful_ops.extend([
1154        op for op in bg.get_operations() if auto_control_deps.op_is_stateful(op)
1155    ])
1156
1157  if stateful_ops:
1158    op_fn = gen_functional_ops.case
1159  else:
1160    op_fn = gen_functional_ops.stateless_case
1161
1162  # Create the Case op.
1163  with ops.control_dependencies(
1164      sum((list(bg.control_captures) for bg in branch_graphs), [])):
1165
1166    def _make_op(inputs):
1167      case_op, tensors = util.get_op_and_outputs(op_fn(
1168          branch_index,
1169          inputs, [t.dtype for t in branch_graphs[0].outputs],
1170          [util.create_new_tf_function(g) for g in branch_graphs],
1171          output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
1172          name=name))
1173      _copy_handle_data(tensors, *[g.outputs for g in branch_graphs])
1174      if case_op is not None:
1175        util.maybe_set_lowering_attr(case_op, lower_using_switch_merge)
1176        util.maybe_propagate_compile_time_consts_in_xla(case_op)
1177        _set_read_only_resource_inputs_attr(case_op, branch_graphs)
1178        # Prevent fetching since the variant outputs can't be fetched directly.
1179        case_op.graph.prevent_fetching(case_op)
1180
1181        # Store the branch graphs so they can be reused during the gradient
1182        # pass.
1183        for i, bg in enumerate(branch_graphs):
1184          bg.outer_graph = ops.get_default_graph()
1185          setattr(case_op, "_branch_graph_{}".format(i), bg)
1186
1187      return tensors
1188    tensors = util.run_as_function_for_tape_gradients(_make_op, case_inputs)
1189
1190  # Return identities for each output of the Case op, rather than the output of
1191  # the Case op directly. This makes pruning work if the output of switch_case()
1192  # is fetched: the lowering pass converts the Case outputs into IdentityN
1193  # outputs, which if fetched will cause all ops in the taken branch to be run
1194  # (since it takes all merge ops as input). After lowering, each output
1195  # identity op will end up with only the appropriate merge op as input.
1196  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
1197  # correct output structure
1198  tensors = [array_ops.identity(t) for t in tensors]
1199
1200  return _pack_sequence_as(branch_graphs[0].structured_outputs, tensors)
1201
1202
1203def _set_read_only_resource_inputs_attr(op, branch_graphs):
1204  """Sets the list of resource inputs which are read-only.
1205
1206  This is used by AutomaticControlDependencies.
1207
1208  Args:
1209    op: If or Case Operation.
1210    branch_graphs: List of branch FuncGraphs.
1211  """
1212  # The first entry in `op.inputs` is the predicate which is not passed to
1213  # branch graphs so len(branch_graphs[i].inputs) == len(op.inputs) - 1.
1214  read_only_indices = set(range(len(op.inputs) - 1))
1215  for branch_graph in branch_graphs:
1216    assert len(branch_graph.inputs) == len(op.inputs) - 1, "should never happen"
1217    if not read_only_indices:
1218      break
1219    branch_read_only_indices = acd.get_read_only_resource_input_indices_graph(
1220        branch_graph)
1221    read_only_indices = read_only_indices.intersection(branch_read_only_indices)
1222  # Convert indices in `branch_graphs[i].inputs` to `op.inputs`.
1223  read_only_indices = [i + 1 for i in read_only_indices]
1224  ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR,
1225                        sorted(read_only_indices))
1226