# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""cond_v2 and gradient.

This is a version of cond that emits a single If op, as well as the gradient
function for If ops produced by cond_v2. This will eventually replace the
current tf.cond implementation once it reaches feature and performance parity.
"""

import collections

from tensorflow.core.framework import types_pb2
from tensorflow.python.eager import backprop_util
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import auto_control_deps_utils as acd
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_util_v2 as util
from tensorflow.python.ops import default_gradient
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import handle_data_util
from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest


# NOTE(skyewm): TensorFlow uses protected class methods and fields to signify
# that they aren't part of the official public API. These protected members
# often need to be used by implementation code however. Rather than litter the
# code with pylint comments, we ignore protected access violations for
# readability.
# pylint: disable=protected-access

_COND = 1
_CASE = 2


def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op."""
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  if not name:
    name = "cond"

  with ops.name_scope(name) as scope:
    true_name = util.unique_fn_name(scope, "true")
    false_name = util.unique_fn_name(scope, "false")

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies
    pred = ops.convert_to_tensor(pred)
    if (tensor_util.is_tf_type(pred) and
        (pred.shape.dims is None or pred.shape.dims)):
      pred = array_ops.squeeze_v2(pred)

    true_graph = func_graph_module.func_graph_from_py_func(
        true_name,
        true_fn, [], {},
        func_graph=util.CondBranchFuncGraph(
            true_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)
    false_graph = func_graph_module.func_graph_from_py_func(
        false_name,
        false_fn, [], {},
        func_graph=util.CondBranchFuncGraph(
            false_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)

    verify_captures(_COND, [true_graph, false_graph])
    return _build_cond(
        pred,
        true_graph,
        false_graph,
        true_graph.external_captures,
        false_graph.external_captures,
        building_gradient=False,
        name=scope)

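# Example (illustrative, not part of this module): inside a `tf.function` or a
# v1 graph, `cond_v2` traces both branches into FuncGraphs and emits a single
# If/StatelessIf op instead of Switch/Merge nodes. A minimal sketch, assuming
# TensorFlow 2.x is importable as `tf` and this module as
# `tensorflow.python.ops.cond_v2`:
#
#   import tensorflow as tf
#   from tensorflow.python.ops import cond_v2
#
#   @tf.function
#   def f(x):
#     return cond_v2.cond_v2(x < 0., lambda: -x, lambda: x)
#
#   f(tf.constant(-3.0))  # ==> 3.0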

@ops.RegisterGradient("StatelessIf")
@ops.RegisterGradient("If")
def _IfGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of an If op produced by cond_v2."""
  # Get the if operator (this logic handles the case where op is a MockOp)
  if_op = op.outputs[0].op
  true_graph, false_graph = get_func_graphs(if_op)
  # Note: op.graph != ops.get_default_graph() when we are computing the gradient
  # of a nested cond.
  assert true_graph.outer_graph == if_op.graph
  assert false_graph.outer_graph == if_op.graph

  # Create grad functions that compute the gradient of the true/false forward
  # graphs. These functions will capture tensors from the forward pass
  # functions.
  true_grad_graph = _create_grad_func(
      true_graph, grads, util.unique_grad_fn_name(true_graph.name))
  false_grad_graph = _create_grad_func(
      false_graph, grads, util.unique_grad_fn_name(false_graph.name))

  # Replaces output None grads with zeros if at least one branch has non-None
  # grad at that index.
  _create_zeros_for_none_grads([true_graph, false_graph],
                               [true_grad_graph, false_grad_graph])

  if (true_grad_graph.op_needs_rewrite or false_grad_graph.op_needs_rewrite):
    # Modify 'op' to output the intermediates needed by the grad functions. Note
    # that all needed intermediates are wrapped in optionals. Each optional
    # intermediate output will have a value iff its corresponding branch is
    # taken.
    # NOTE(skyewm): if there are any active sessions, this modification to `op`
    # may make them unrunnable!

    if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
      # XLA does not yet support optionals, so output intermediates directly and
      # make them match via FakeParams, which can be converted to zeros in XLA.
      # TODO(skyewm,jpienaar): can XLA support optionals?
      true_intermediates = true_grad_graph.xla_intermediates
      false_intermediates = false_grad_graph.xla_intermediates
      extra_true_outputs, extra_false_outputs = _make_intermediates_match_xla(
          [true_graph, false_graph], [true_intermediates, false_intermediates])
    else:
      true_intermediates = true_grad_graph.wrapped_intermediates
      false_intermediates = false_grad_graph.wrapped_intermediates
      # Make outputs match by adding none optionals.
      extra_true_outputs, extra_false_outputs = _make_intermediates_match(
          [true_graph, false_graph], [true_intermediates, false_intermediates])

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)
    # TODO(skyewm): indicate it's an internal bug if this fails.
    _check_same_outputs(_COND, [true_graph, false_graph])

    true_graph.name += "_rewritten"
    false_graph.name += "_rewritten"

    if_op._set_func_attr("then_branch", util.create_new_tf_function(true_graph))
    if_op._set_func_attr("else_branch",
                         util.create_new_tf_function(false_graph))
    if_op._set_type_list_attr("Tout", true_graph.output_types)
    if_op._set_shape_list_attr("output_shapes", true_graph.output_shapes)
    if_op._add_outputs(
        [t.dtype for t in extra_true_outputs],
        [t.shape for t in extra_true_outputs])

  # Resolve references to forward graph tensors in grad graphs and ensure
  # they are in-scope, i.e., belong to one of outer graphs of the grad graph.
  true_grad_inputs = _resolve_grad_inputs(true_graph, true_grad_graph)
  false_grad_inputs = _resolve_grad_inputs(false_graph, false_grad_graph)

  # This modifies true_grad_graph and false_grad_graph.
  _make_output_composite_tensors_match(_COND,
                                       [true_grad_graph, false_grad_graph])

  outputs = _build_cond(
      if_op.inputs[0],
      true_grad_graph,
      false_grad_graph,
      true_grad_inputs,
      false_grad_inputs,
      building_gradient=True,
  )

  # The predicate has no gradient.
  return [None] + outputs

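# Example (illustrative, not part of this module): differentiating through a
# cond builds gradient FuncGraphs for both branches, but only the taken branch
# contributes at runtime. A sketch, assuming `tf` and `cond_v2` as imported in
# the previous example:
#
#   @tf.function
#   def g(x):
#     y = cond_v2.cond_v2(x > 0., lambda: x * x, lambda: -x)
#     return tf.gradients(y, x)[0]
#
#   g(tf.constant(3.0))   # ==> 6.0  (gradient of x * x)
#   g(tf.constant(-3.0))  # ==> -1.0 (gradient of -x)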

def _build_cond(pred,
                true_graph,
                false_graph,
                true_inputs,
                false_inputs,
                building_gradient,
                name=None):
  """Creates an If op from the specified predicate, branch functions and inputs.

  Note that this modifies true_graph and false_graph to make the inputs match,
  and to output all intermediate values so they're available for the gradient
  computation.

  true_graph and false_graph need not have the same input types, but they must
  have the same output types.

  Args:
    pred: boolean Tensor
    true_graph: FuncGraph
    false_graph: FuncGraph
    true_inputs: a list of Tensors to be passed to true_graph as input.
    false_inputs: a list of Tensors to be passed to false_graph as input.
    building_gradient: Whether this is a gradient If op.
    name: the name for the If op.

  Returns:
    A list of Tensors which are the outputs of the If op. Does not include added
    intermediate outputs.
  """
  _make_indexed_slices_indices_types_match(_COND, [true_graph, false_graph])
  _check_same_outputs(_COND, [true_graph, false_graph])

  # Add inputs to true_graph and false_graph to make them match. Note that
  # this modifies true_graph and false_graph.
  cond_inputs = _make_inputs_match([true_graph, false_graph],
                                   [true_inputs, false_inputs])
  # We do not output intermediates of the gradient If op since this is just
  # for backwards compatibility with existing code.
  if not building_gradient and util.output_all_intermediates():
    # Add all intermediate tensors as function outputs so they're available for
    # the gradient computation. Since the outputs of the two functions must
    # match, we wrap all the intermediates in optionals. Each intermediate
    # output will have a value iff its corresponding branch is taken.

    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Wrap intermediates in optionals.
    wrapped_true_intermediates = _wrap_intermediates(true_graph,
                                                     true_intermediates)
    wrapped_false_intermediates = _wrap_intermediates(false_graph,
                                                      false_intermediates)

    # Make outputs match by adding none optionals.
    extra_true_outputs, extra_false_outputs = _make_intermediates_match(  # pylint: disable=unbalanced-tuple-unpacking
        [true_graph, false_graph],
        [wrapped_true_intermediates, wrapped_false_intermediates])

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)
    _check_same_outputs(_COND, [true_graph, false_graph])

  # Create the If op.
  with ops.control_dependencies(
      list(true_graph.control_captures) + list(false_graph.control_captures)):
    true_stateful_ops = [
        op for op in true_graph.get_operations() if op._is_stateful
    ]
    false_stateful_ops = [
        op for op in false_graph.get_operations() if op._is_stateful
    ]
    if (true_stateful_ops or false_stateful_ops):
      op_fn = gen_functional_ops._if
    else:
      op_fn = gen_functional_ops.stateless_if

    def _make_op(inputs):
      if_op, tensors = util.get_op_and_outputs(op_fn(
          pred,
          inputs, [t.dtype for t in true_graph.outputs],
          util.create_new_tf_function(true_graph),
          util.create_new_tf_function(false_graph),
          output_shapes=_get_output_shapes(true_graph.outputs,
                                           false_graph.outputs),
          name=name))
      _copy_handle_data(tensors, true_graph.outputs, false_graph.outputs)
      # `if_op` is None if this is a `StatelessIf` op with no outputs.
      if if_op is not None:
        # The true and false graphs have already been created, and we need that
        # to happen before we know which tensors will be captured and so whether
        # to wrap the cond in a tf.function. Post-hoc mutation of the branch
        # `outer_graph` properties seems like the only option if we want to
        # conditionally wrap in a function.
        true_graph.outer_graph = ops.get_default_graph()
        false_graph.outer_graph = ops.get_default_graph()
        if_op._true_graph = true_graph
        if_op._false_graph = false_graph
        util.maybe_set_lowering_attr(if_op)
        util.maybe_propagate_compile_time_consts_in_xla(if_op)
        _set_read_only_resource_inputs_attr(if_op, [true_graph, false_graph])
        # Prevent fetching since the variant outputs can't be fetched directly.
        if_op.graph.prevent_fetching(if_op)
      return tensors
    tensors = util.run_as_function_for_tape_gradients(_make_op, cond_inputs)

  # Return identities for each output of the If op, rather than the output of
  # the If op directly. This makes pruning work if the output of cond() is
  # fetched: the lowering pass converts the If outputs into IdentityN outputs,
  # which if fetched will cause all ops in the taken branch to be run (since
  # it takes all merge ops as input). After lowering, each output identity op
  # will end up with only the appropriate merge op as input.
  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
  # correct output structure
  tensors = [array_ops.identity(t) for t in tensors]

  structured_output_specs = _get_compatible_structured_output_specs(true_graph,
                                                                    false_graph)
  return _pack_sequence_as(structured_output_specs, tensors)


def get_func_graphs(op):
  """Returns `FuncGraph`s for the input op branches.

  Args:
    op: The If or Case Operation.

  Returns:
    A tuple of the `FuncGraph`s of the then_branch and else_branch (all branches
    for Case).
  """

  def _get_func_graph_for_branch(name_attr_list, cached_attr_name=None):
    """Generates and returns a FuncGraph for the given branch."""
    func_graph = None
    if cached_attr_name is not None:
      func_graph = getattr(op, cached_attr_name, None)
    inputs = op.inputs[1:]  # First input is pred.
    if func_graph is None:
      input_shapes = [t.shape for t in inputs]
      func_graph = util.get_func_graph(op, input_shapes, name_attr_list.name)
    for external_t, internal_t in zip(inputs, func_graph.inputs):
      handle_data_util.copy_handle_data(external_t, internal_t)
    func_graph.reset_captures(zip(inputs, func_graph.inputs))
    # Link the op so that the gradient code can use it.
    func_graph._forward_cond = op
    return func_graph

  if op.type in ["If", "StatelessIf"]:
    return (_get_func_graph_for_branch(
        op.get_attr("then_branch"), "_true_graph"),
            _get_func_graph_for_branch(
                op.get_attr("else_branch"), "_false_graph"))
  elif op.type in ["Case", "StatelessCase"]:
    return [_get_func_graph_for_branch(branch_fn, "_branch_graph_{}".format(i))
            for i, branch_fn in enumerate(op.get_attr("branches"))]
  else:
    raise ValueError("Unsupported op type: {}".format(op.type))


def _get_compatible_structured_output_specs(true_graph, false_graph):
  """Returns the most specific compatible specs of graph structured outputs."""
  return nest.map_structure(_get_compatible_spec,
                            true_graph.structured_outputs,
                            false_graph.structured_outputs)


def _get_compatible_spec(value_or_spec1, value_or_spec2):
  """Returns the most specific compatible spec.

  Args:
    value_or_spec1: A TypeSpec or a value that has a defined TypeSpec.
    value_or_spec2: A TypeSpec or a value that has a defined TypeSpec.

  Returns:
    The most specific compatible TypeSpec of the inputs.

  Raises:
    TypeError: If value_or_spec1 is not compatible with value_or_spec2.
  """
  spec1 = _get_spec_for(value_or_spec1)
  spec2 = _get_spec_for(value_or_spec2)

  # pylint: disable=protected-access
  common = spec1._without_tensor_names().most_specific_common_supertype(
      [spec2._without_tensor_names()])
  if common is None:
    raise TypeError(f"No common supertype of {spec1} and {spec2}.")
  return common

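# Example (illustrative, not part of this module): the common supertype
# relaxes dimensions on which the branch specs disagree. Assuming `tf` is
# TensorFlow 2.x:
#
#   _get_compatible_spec(tf.TensorSpec([2, 3], tf.float32),
#                        tf.TensorSpec([None, 3], tf.float32))
#   # ==> TensorSpec(shape=(None, 3), dtype=tf.float32)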

def _get_spec_for(value_or_spec):
  """Returns TypeSpec of a value or itself if it is a TypeSpec already."""
  if isinstance(value_or_spec, type_spec.TypeSpec):
    return value_or_spec
  return type_spec.type_spec_from_value(value_or_spec)


def _grad_fn(func_graph, grads):
  """The gradient function for each conditional branch.

  This function builds the gradient graph of the corresponding forward-pass
  conditional branch in `func_graph`. This is done by differentiating
  func_graph's outputs w.r.t. its inputs.

  Args:
    func_graph: FuncGraph. The corresponding forward-pass function.
    grads: The list of input gradient Tensors.

  Returns:
    The output gradient Tensors.
  """
  # Filter out untrainable function outputs.
  # NOTE(skyewm): If we don't do this, the untrainable tensors can sometimes
  # cause _GradientsHelper to raise an exception (e.g. the implementation
  # doesn't expect 'ys' to contain boolean tensors).
  assert len(func_graph.outputs) == len(grads)
  ys = []
  grad_ys = []
  for y, grad_y in zip(func_graph.outputs, grads):
    if not backprop_util.IsTrainable(y):
      continue
    ys.append(y)
    grad_ys.append(grad_y)

  # Build the gradient graph. Note that this builds the gradient computation of
  # func_graph in the current graph, which requires capturing tensors from
  # func_graph. The captured func_graph tensors are resolved to external tensors
  # in _resolve_grad_inputs.
  result = gradients_util._GradientsHelper(
      ys, func_graph.inputs, grad_ys=grad_ys,
      src_graph=func_graph)

  return result


def _create_grad_func(func_graph, grads, name):
  """Returns the FuncGraph representation of _grad_fn."""
  return func_graph_module.func_graph_from_py_func(
      name,
      lambda: _grad_fn(func_graph, grads), [], {},
      func_graph=_CondGradFuncGraph(name, func_graph))


def _resolve_grad_inputs(cond_graph, grad_graph):
  """Returns the tensors to pass as inputs to `grad_graph`.

  The `grad_graph` may have external references to
  1. Its outer graph containing the input gradients. These references are kept
     as is.
  2. Tensors in the forward pass graph. These tensors may not be "live"
     when the gradient is being computed. We replace such references by their
     corresponding tensor in `cond_graph.outer_graph`. In the case of nested
     control flow or functions, the gradient logic handling
     `grad_graph.outer_graph` will make sure the tensor from
     `cond_graph.outer_graph` is also correctly captured.

  Args:
    cond_graph: FuncGraph. The forward-pass function.
    grad_graph: FuncGraph. The gradients function.

  Returns:
    A list of input tensors to be passed to grad_graph.
  """
  new_inputs = []

  for t in grad_graph.external_captures:
    # `t` must either be in `grad_graph.outer_graph` or in the forward
    # `cond_graph`.
    if t.graph != grad_graph.outer_graph:
      assert t.graph == cond_graph
      # `internal_captures` are not treated as intermediates and hence not added
      # to If op outputs. So we get the outer tensor corresponding to those
      # from the list of `external_captures`.
      for i, output in enumerate(t.graph.outputs):
        if output is t:
          t = t.graph._forward_cond.outputs[i]
          break
      else:
        for i, output in enumerate(t.graph.internal_captures):
          if output is t:
            t = t.graph.external_captures[i]
            break
        else:
          raise ValueError("Could not find external tensor capture {tensor} in "
                           "captures or outputs".format(tensor=t))

      # Note: We rely on the capturing logic of the gradient If op graph to
      # correctly capture the tensors in `cond_graph.outer_graph`. Both cond_v2
      # and while_v2 handle this while building their gradient functions.
      assert t.graph == cond_graph.outer_graph
    new_inputs.append(t)

  return new_inputs


def _get_intermediates(func_graph):
  """Returns intermediate tensors of `func_graph` for gradient computation."""
  intermediates = []
  for op in func_graph.get_operations():
    for t in op.outputs:
      if t in func_graph.inputs: continue
      if t in func_graph.outputs: continue
      if t.dtype is dtypes.resource:
        continue
      # Accumulating mutexes can cause deadlock.
      if op.type == "MutexLock":
        continue
      intermediates.append(t)
  return intermediates


def _make_intermediates_match(branch_graphs, branch_optionals):
  """Returns new optionals lists that have matching signatures.

  This is done by mirroring each list in the other using none optionals.
  There is no merging of like optionals.

  Args:
    branch_graphs: `list` of `FuncGraph`.
    branch_optionals: `list` of `list`s of optional `Tensor`s from other
      branch_graphs

  Returns:
    A `list` of `list`s of `Tensor`s for each branch_graph. Each list has the
    same number of `Tensor`s, all of which will be optionals of the same
    shape/type.
  """
  new_branch_optionals = []
  # Since the intermediates are optionals with dtype variant, we only need
  # enough room for the longest list of intermediates.
  intermediates_size = max(len(o) for o in branch_optionals)
  for i, branch_graph in enumerate(branch_graphs):
    other_optionals = _create_none_optionals(
        branch_graph, intermediates_size - len(branch_optionals[i]))
    new_branch_optionals.append(branch_optionals[i] + other_optionals)
  return new_branch_optionals

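# Example (illustrative, with hypothetical tensors): if the true branch wrapped
# two intermediates and the false branch wrapped none, both lists are padded to
# length two, the shorter one with "none" optionals created in its graph:
#
#   _make_intermediates_match([true_graph, false_graph], [[opt_a, opt_b], []])
#   # ==> [[opt_a, opt_b], [none_opt_0, none_opt_1]]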

def _make_intermediates_match_xla(branch_graphs, branch_intermediates):
  """Like _make_intermediates_match but for the XLA case."""
  new_branch_intermediates = []
  for i, branch_graph in enumerate(branch_graphs):
    other_fakeparams = _create_fakeparams(
        branch_graph,
        sum((bi for bi in branch_intermediates
             if bi is not branch_intermediates[i]), []))
    num_preceding = sum(len(bi) for bi in branch_intermediates[:i])
    new_branch_intermediates.append(other_fakeparams[:num_preceding] +
                                    branch_intermediates[i] +
                                    other_fakeparams[num_preceding:])
  return new_branch_intermediates


def _make_inputs_match(branch_graphs, branch_inputs):
  """Modifies branch_graphs so they have the same input signature.

  This method reorders and/or adds parameters to each graph in branch_graphs so
  they have the same input signature, and updates the 'inputs' and 'captured'
  fields of each graph accordingly. It uses the input tensors from the outer
  graph to avoid duplicating shared arguments.

  Args:
    branch_graphs: a `list` of `FuncGraph`
    branch_inputs: a `list` of `list`s of `Tensor`s in the outer graph. The
      inputs for the corresponding graph in `branch_graphs`.

  Returns:
    A new list of Tensors from the outer graph that are the new inputs for each
    branch_graph. This is a deduped version of `sum(branch_inputs)`.
  """
  assert len(branch_graphs) == len(branch_inputs)
  added_inputs = set()
  new_inputs = []
  for branch_in in branch_inputs:
    for tensor in branch_in:
      tensor_id = ops.tensor_id(tensor)
      if tensor_id not in added_inputs:
        added_inputs.add(tensor_id)
        new_inputs.append(tensor)

  for branch_graph, branch_in in zip(branch_graphs, branch_inputs):
    input_ids = [ops.tensor_id(t) for t in branch_in]
    branch_input_to_param = dict(zip(input_ids, branch_graph.inputs))
    input_list = []
    for in_t in new_inputs:
      param = branch_input_to_param.get(ops.tensor_id(in_t))
      if param is None:
        param = _create_dummy_input(branch_graph, in_t)
      input_list.append(param)

    branch_graph.inputs = input_list

    # Rewrite the FuncGraphs' state to reflect the new inputs.
    branch_graph.reset_captures(zip(new_inputs, branch_graph.inputs))

  return new_inputs

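# Example (illustrative, with hypothetical outer-graph tensors a, b, c): if
# true_fn captured [a, b] and false_fn captured [b, c], the deduped If op
# inputs become [a, b, c]; each branch graph gains a placeholder for the
# input it did not originally capture:
#
#   _make_inputs_match([true_graph, false_graph], [[a, b], [b, c]])
#   # ==> [a, b, c]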

def _create_zeros_for_none_grads(forward_graphs, grad_graphs):
  """Creates zeros for None output grads if any branch has a non-None grad.

  Args:
    forward_graphs: List of forward FuncGraphs.
    grad_graphs: List of grad FuncGraphs.
  """
  assert len(forward_graphs) == len(grad_graphs)
  branch_outputs = [g.structured_outputs for g in grad_graphs]
  num_outputs_per_branch = [len(outs) for outs in branch_outputs]
  assert len(set(num_outputs_per_branch)) == 1, num_outputs_per_branch
  for output_idx, branch_outs in enumerate(zip(*branch_outputs)):
    if (any(t is None for t in branch_outs) and
        any(t is not None for t in branch_outs)):
      for branch_index, t in enumerate(branch_outs):
        if t is None:
          with grad_graphs[branch_index].as_default():
            zeros = default_gradient.zeros_like(
                forward_graphs[branch_index].inputs[output_idx])
            grad_graphs[branch_index].structured_outputs[output_idx] = zeros

  for grad_graph in grad_graphs:
    grad_graph.outputs = [
        t for t in func_graph_module.flatten(grad_graph.structured_outputs)
        if t is not None
    ]


def _make_output_composite_tensors_match(op_type, branch_graphs):
  """Modifies each branch_graph's outputs to have the same output signature.

  Currently the only transformation implemented is turning a Tensor into an
  equivalent IndexedSlices if the other branch returns an IndexedSlices.
  Updates branch_graph.{outputs,structured_outputs} for each branch_graph in
  branch_graphs.

  Args:
    op_type: _COND or _CASE
    branch_graphs: `list` of `FuncGraph`

  Raises:
    TypeError: if a set of outputs cannot be rewritten.
  """
  # Note: since this is only used for gradient graphs, we do not expect the
  # outputs to be structured (e.g. nested lists), and thus do not need to use
  # nest.flatten, etc.
  assert branch_graphs
  branch_outputs = [g.structured_outputs for g in branch_graphs]
  outputs_per_branch = list(len(outs) for outs in branch_outputs)
  assert len(set(outputs_per_branch)) == 1, outputs_per_branch

  for output_idx, branch_outs in enumerate(zip(*branch_outputs)):
    if len(set(type(out) for out in branch_outs)) == 1:
      continue
    if not any(
        isinstance(out, indexed_slices.IndexedSlices) for out in branch_outs):
      continue
    for branch_idx, branch_out in enumerate(branch_outs):
      if isinstance(branch_out, indexed_slices.IndexedSlices):
        continue
      elif isinstance(branch_out, ops.Tensor):
        with branch_graphs[branch_idx].as_default():
          branch_outputs[branch_idx][output_idx] = math_ops._as_indexed_slices(
              branch_out)
      else:
        raise TypeError(
            "Cannot reconcile {op_name} {output_idx}-th outputs:\n"
            "  outputs from all branches: {outputs}".format(
                op_name="tf.cond" if op_type == _COND else "tf.switch_case",
                output_idx=output_idx,
                outputs=branch_outs))

  for branch_graph, branch_outs in zip(branch_graphs, branch_outputs):
    branch_graph.structured_outputs = branch_outs
    branch_graph.outputs = [
        t for t in func_graph_module.flatten(branch_outs) if t is not None
    ]

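# Example (illustrative): this situation typically arises in gradient graphs
# when one branch's gradient for an output is sparse (an IndexedSlices, e.g.
# from the gradient of tf.gather) while the other branch's gradient for the
# same output is a dense Tensor; the dense Tensor is converted so both
# branches return IndexedSlices of matching structure.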

def _make_indexed_slices_indices_types_match(op_type, branch_graphs):
  """Match dtype of IndexedSlices.indices in outputs of branch_graphs."""
  assert branch_graphs
  # Indices of `IndexedSlices.indices` tensors in `branch_graphs[i].outputs`.
  indexed_slice_indices = []
  current_index = 0
  # Note that this still contains Nones. We leave those in so that error
  # messages contain the correct indices. We handle the Nones later when
  # updating `current_index`.
  branch_outputs_flat_with_composites = [
      nest.flatten(branch_graph.structured_outputs, expand_composites=False)
      for branch_graph in branch_graphs
  ]
  outs_per_branch = [len(outs) for outs in branch_outputs_flat_with_composites]
  assert len(set(outs_per_branch)) == 1, outs_per_branch
  # Store indices of IndexedSlices.indices in `indexed_slice_indices`.
  for output_idx, branch_outs in enumerate(
      zip(*branch_outputs_flat_with_composites)):
    if len(
        set(
            isinstance(out, indexed_slices.IndexedSlices)
            for out in branch_outs)) != 1:
      raise TypeError("Cannot reconcile tf.{op_name} {output_idx}-th outputs:\n"
                      "  branches returned: {outputs}".format(
                          op_name="cond" if op_type == _COND else "switch_case",
                          output_idx=output_idx,
                          outputs=branch_outs))
    if isinstance(branch_outs[0], indexed_slices.IndexedSlices):
      # indices is the second component of the composite tensor.
      indexed_slice_indices.append(current_index + 1)
    if nest.is_nested_or_composite(branch_outs[0]):
      current_index += len(nest.flatten(branch_outs[0], expand_composites=True))
    elif branch_outs[0] is not None:
      # `FuncGraph.outputs` does not contain Nones so no need to update the
      # counter in that case.
      current_index += 1

  if not indexed_slice_indices:
    return

  # `FuncGraph.outputs` is the flattened `FuncGraph.structured_outputs` minus
  # the Nones.
  if current_index != len(branch_graphs[0].outputs):
    raise ValueError("Insufficient elements in branch_graphs[0].outputs.\n"
                     "Expected: %i\n"
                     "Actual: %i" %
                     (current_index, len(branch_graphs[0].outputs)))

  # Cast indices with mismatching types to int64.
  for index in indexed_slice_indices:
    if any(bg.outputs[index].dtype not in (dtypes.int32, dtypes.int64)
           for bg in branch_graphs):
      raise TypeError("Type of IndexedSlices.indices must be int32 or int64. "
                      "Found: %s" %
                      str([bg.outputs[index].dtype for bg in branch_graphs]))
    if len(set(bg.outputs[index].dtype for bg in branch_graphs)) != 1:
      for branch_graph in branch_graphs:
        if branch_graph.outputs[index].dtype == dtypes.int32:
          with branch_graph.as_default():
            branch_graph.outputs[index] = math_ops.cast(
                branch_graph.outputs[index], dtypes.int64)

  for branch_graph in branch_graphs:
    branch_graph.structured_outputs = _pack_sequence_as(
        branch_graph.structured_outputs, branch_graph.outputs)


def _pack_sequence_as(structured_outputs, op_outputs):
  """Packs the outputs of the gradient If/Case op.

  The branch functions may contain None's in the list of `structured_outputs`.
  `op_outputs` has those outputs missing. So we need to add those Nones to the
  list of `op_outputs` and then pack it in the same structure as
  `structured_outputs`.

  Args:
    structured_outputs: structured_outputs from one of the branch functions.
    op_outputs: List of output tensors of the op.

  Returns:
    `op_outputs` packed like `structured_outputs`.
  """
  outputs_with_nones = []
  counter = 0
  for output in nest.flatten(structured_outputs, expand_composites=True):
    if output is None:
      outputs_with_nones.append(None)
    else:
      outputs_with_nones.append(op_outputs[counter])
      counter += 1
  return func_graph_module.pack_sequence_as(structured_outputs,
                                            outputs_with_nones)

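# Example (illustrative, with hypothetical tensors): if a branch's
# structured_outputs are (None, t0, {'a': t1}), the op itself only has two
# outputs, and packing restores the None in its original position:
#
#   _pack_sequence_as((None, t0, {'a': t1}), [out0, out1])
#   # ==> (None, out0, {'a': out1})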

def _wrap_intermediates(func_graph, intermediates):
  with func_graph.as_default():
    return [gen_dataset_ops.optional_from_value([t]) for t in intermediates]


def _create_dummy_input(func_graph, template_tensor):
  """Creates tensors in func_graph to represent template_tensors.

  Args:
    func_graph: FuncGraph.
    template_tensor: a tensor in the outer graph.

  Returns:
    A tensor in func_graph.
  """
  with func_graph.as_default():
    return array_ops.placeholder(
        template_tensor.dtype, shape=template_tensor.shape)


def _create_none_optionals(func_graph, n):
  """Creates `n` `None` optionals in func_graph.

  Args:
    func_graph: FuncGraph.
    n: `int` the number of `None` optionals to make.

  Returns:
    A list of tensors in func_graph.
  """
  with func_graph.as_default():
    return [gen_dataset_ops.optional_none() for _ in range(n)]


def _create_fakeparams(func_graph, template_tensors):
  """Create FakeParams for the XLA case."""
  with func_graph.as_default():
    return [gen_functional_ops.fake_param(dtype=t.dtype, shape=t.shape)
            for t in template_tensors]


def _check_same_outputs(op_type, graphs):
  """Raises an error if `graphs` have different outputs."""

  def error(branch_idx, error_detail):
    raise TypeError(
        "{b0_name} and {bn_name} arguments to {op_name} must have the same "
        "number, type, and overall structure of return values.\n"
        "\n"
        "{b0_name} output: {b0_out}\n"
        "{bn_name} output: {bn_out}\n"
        "\n"
        "Error details:\n"
        "{detail}".format(
            b0_name="true_fn" if op_type == _COND else "branches[0]",
            bn_name=("false_fn" if op_type == _COND else
                     "branches[{}]".format(branch_idx)),
            op_name="tf.cond" if op_type == _COND else "tf.switch_case",
            b0_out=graphs[0].structured_outputs,
            bn_out=graphs[branch_idx].structured_outputs,
            detail=error_detail))

  for b in range(1, len(graphs)):
    try:
      nest.assert_same_structure(
          graphs[0].structured_outputs,
          graphs[b].structured_outputs,
          expand_composites=True)
    except (ValueError, TypeError) as e:
      error(b, str(e))

    op_type_str = "cond" if op_type == _COND else "case"
    if len(graphs[0].outputs) != len(graphs[b].outputs):
      raise ValueError("Lengths of branch outputs of {op_type} must match.\n"
                       "len(graphs[0].outputs): {len_0}\n"
                       "len(graphs[{b}].outputs): {len_b}\n".format(
                           op_type=op_type_str,
                           len_0=len(graphs[0].outputs),
                           b=b,
                           len_b=len(graphs[b].outputs)))
    for b0_out, bn_out in zip(graphs[0].outputs, graphs[b].outputs):
      if b0_out.dtype != bn_out.dtype:
        error(b, "%s and %s have different types" % (b0_out, bn_out))

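# Example (illustrative): branches whose return values have mismatched dtypes
# fail this check before the If op is built. Inside a `tf.function`, assuming
# `tf` and `cond_v2` as imported in the earlier sketches and `pred` is a
# hypothetical boolean tensor:
#
#   cond_v2.cond_v2(pred,
#                   lambda: tf.constant(1.0),   # float32
#                   lambda: tf.constant(1))     # int32
#   # ==> TypeError: true_fn and false_fn arguments to tf.cond must have the
#   #     same number, type, and overall structure of return values. ...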

def _get_output_shapes(*branch_graph_outputs):
  output_shapes = []
  for out_by_branch in zip(*branch_graph_outputs):
    shape = out_by_branch[0].shape
    for other_out in out_by_branch[1:]:
      shape = shape.most_specific_compatible_shape(other_out.shape)
    output_shapes.append(shape)
  return output_shapes

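# Example (illustrative): the If op's static output shapes are the most
# specific shapes compatible with every branch. For instance, branch shapes
# [2, None] and [None, 3] combine to [None, None], while [2, 3] and [2, 3]
# stay [2, 3].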

def _copy_handle_data(external_tensors, *branch_graph_outputs):
  """Combines shapes in handle data and sets metadata on `external_tensors`."""
  for tensors in zip(external_tensors, *branch_graph_outputs):
    external = tensors[0]
    internal = tensors[1:]
    internal_handle_data = []
    for tensor in internal:
      handle_data = handle_data_util.get_resource_handle_data(tensor)
      # NOTE: Assumes handle data has only one ShapeAndType entry. It's
      # unclear how to combine different lengths across branches.
      if not handle_data.is_set or len(handle_data.shape_and_type) != 1:
        break
      internal_handle_data.append(handle_data)
    else:  # There is handle data, so we need to combine it.
      combined_shape = tensor_shape.TensorShape(None)
      combined_dtype = None
      for handle_data in internal_handle_data:
        handle_shape = tensor_shape.TensorShape(
            handle_data.shape_and_type[0].shape)
        combined_shape = combined_shape.most_specific_compatible_shape(
            handle_shape)
        if combined_dtype is None:
          combined_dtype = handle_data.shape_and_type[0].dtype
        elif handle_data.shape_and_type[0].dtype != combined_dtype:
          # Variants from different branches have different dtypes. The
          # combined variant has no static dtype.
          combined_dtype = types_pb2.DT_INVALID
      combined_handle_data = internal_handle_data[0]
      combined_handle_data.shape_and_type[0].shape.CopyFrom(
          combined_shape.as_proto())
      combined_handle_data.shape_and_type[0].dtype = combined_dtype
      handle_data_util.set_handle_data(external, combined_handle_data)


def verify_captures(op_type, branch_graphs):
  """Verify that a branch's tensor is not accessed in another branch fn."""
  # Note: It is technically not possible for lower-branch_index branches to
  # capture tensors from higher-branch_index branches, because of the order of
  # branch graph construction, but we check all for completeness and to
  # guard against potential future changes.
  other_branch_graphs = {g: i for i, g in enumerate(branch_graphs)}
  for i, branch_graph in enumerate(branch_graphs):
    for t in branch_graph.external_captures:
      if not isinstance(t, ops.EagerTensor) and t.graph in other_branch_graphs:
        branch_names = ["true_fn", "false_fn"] if op_type == _COND else [
            "branch {}".format(bi) for bi in range(len(branch_graphs))]
        raise ValueError(
            "Tensor {tname} in {b0name} is accessed from {b1name}.".format(
                tname=t.name,
                b0name=branch_names[other_branch_graphs[t.graph]],
                b1name=branch_names[i]))


class _CondGradFuncGraph(util.CondBranchFuncGraph):
  """FuncGraph for the gradient function of the branch of an If op.

  Handles wrapping and unwrapping intermediate values that are captured by the
  gradient computation in optionals.

  Attributes:
    op_needs_rewrite: True if any intermediates were captured, meaning the
      forward If op needs to be rewritten to output the wrapped intermediates.
  """

  def __init__(self, name, forward_graph):
    super(_CondGradFuncGraph, self).__init__(
        name, collections=ops.get_default_graph()._collections)  # pylint: disable=protected-access
    self.op_needs_rewrite = False
    self._forward_graph = forward_graph
    # Maps from forward intermediate tensor -> the unwrapped captured
    # intermediate.
    self._indirect_captures = {}
    # Maps unwrapped intermediate -> optional-wrapped intermediate in the
    # forward graph.
    self._wrapped_intermediates = collections.OrderedDict()
    # Raw intermediates captured from the forward graph. Populated iff we're in
    # an XLA context.
    self._xla_intermediates = []
    # Maps forward intermediate constant valued tensor's id to the constant
    # created in this graph for that tensor.
    self._captured_constants = {}

  @property
  def wrapped_intermediates(self):
    """The optional-wrapped intermediates captured from the forward graph."""
    return list(self._wrapped_intermediates.values())

  @property
  def xla_intermediates(self):
    """Raw intermediates captured from the forward graph if XLA is enabled."""
    return self._xla_intermediates

  def _capture_helper(self, tensor, name):
    if (tensor.graph is not self._forward_graph or
        any(tensor is t for t in self._forward_graph.inputs) or
        any(tensor is t for t in self._forward_graph.outputs)):
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

    tensor_id = ops.tensor_id(tensor)

    # If `tensor` is a graph-building time constant, we create a constant with
    # the same value in the backward graph instead of capturing it.
    if tensor_id in self._captured_constants:
      return self._captured_constants[tensor_id]
    elif constant_op.is_constant(tensor):
      self._captured_constants[tensor_id] = constant_op.constant(
          tensor_util.constant_value(tensor), dtype=tensor.dtype)
      return self._captured_constants[tensor_id]

    if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
      # XLA does not yet support optionals, so capture intermediates directly.
      # TODO(skyewm,jpienaar): can XLA support optionals?
      if all(tensor is not capture for capture in self.external_captures):
        self.xla_intermediates.append(tensor)
        self.op_needs_rewrite = True
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

    captured_tensor = self._indirect_captures.get(tensor_id)
    if captured_tensor is not None:
      return captured_tensor

    # 'tensor' is an uncaptured intermediate in the forward graph.
    # If it is not a resource, we wrap it in an optional in the forward graph
    # and capture the optional normally. We then unwrap the captured optional
    # value in the gradient graph to get the raw intermediate value.
    # If it is a resource, we trace the resource up to the input in the forward
    # graph and capture that.

    if tensor.dtype == dtypes.resource:
      # Index of the forward graph input corresponding to the resource tensor.
      index = util.resource_input_index(
          tensor.name, [t.name for t in self._forward_graph.inputs],
          {op.name: op.node_def for op in self._forward_graph.get_operations()},
          self._forward_graph._functions)
      # This gets mapped to the corresponding If op input in
      # `_resolve_grad_inputs`.
      captured_tensor = super(_CondGradFuncGraph, self)._capture_helper(
          self._forward_graph.inputs[index], name)
    else:
      if tensor_id not in self._wrapped_intermediates:
        # If the gradient has already been computed for this If op, 'tensor' may
        # already be wrapped.
        for consumer in tensor.consumers():
          if (consumer.type == "OptionalFromValue" and
              any(consumer.outputs[0] is output
                  for output in self._forward_graph.outputs)):
            optional = consumer.outputs[0]
            break
        else:
          # 'tensor' hasn't been wrapped, do it now.
          with self._forward_graph.as_default():
            optional = gen_dataset_ops.optional_from_value([tensor])
          self.op_needs_rewrite = True
        self._wrapped_intermediates[tensor_id] = optional

      optional = self._wrapped_intermediates[tensor_id]
      captured_optional = super(_CondGradFuncGraph,
                                self)._capture_helper(optional, name)
      captured_tensor = gen_dataset_ops.optional_get_value(
          captured_optional, [tensor.dtype], [tensor.shape])[0]

    self._indirect_captures[tensor_id] = captured_tensor
    return captured_tensor


def indexed_case(branch_index,
                 branch_fns,
                 name="indexed_case",
                 lower_using_switch_merge=None):
  """Like cond_v2, except emits a Case op instead of an If."""
  if isinstance(branch_index, int):
    raise TypeError("branch_index must not be a Python int", branch_index)

  with ops.name_scope(name) as scope:
    branch_names = [
        util.unique_fn_name(scope, "branch{}".format(b))
        for b in range(len(branch_fns))
    ]

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies
    branch_index = ops.convert_to_tensor(branch_index, name="branch_index")

    branch_graphs = []
    for branch_name, branch_fn in zip(branch_names, branch_fns):
      branch_graphs.append(
          func_graph_module.func_graph_from_py_func(
              branch_name,
              branch_fn,
              [],
              {},
              func_graph=util.CondBranchFuncGraph(
                  branch_name,
                  collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
              add_control_dependencies=add_control_dependencies,
              op_return_value=branch_index))

    verify_captures(_CASE, branch_graphs)
    return _build_case(
        branch_index,
        branch_graphs, [g.external_captures for g in branch_graphs],
        name=scope,
        lower_using_switch_merge=lower_using_switch_merge)

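# Example (illustrative, not part of this module): `indexed_case` backs
# `tf.switch_case`. A minimal sketch inside a `tf.function`, assuming `tf` and
# `cond_v2` as imported in the earlier sketches:
#
#   @tf.function
#   def h(i):
#     return cond_v2.indexed_case(
#         i, [lambda: tf.constant(10), lambda: tf.constant(20)])
#
#   h(tf.constant(1))  # ==> 20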

@ops.RegisterGradient("Case")
@ops.RegisterGradient("StatelessCase")
def _CaseGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a Case op produced by tf.switch_case."""
  # Get the Case operator (this logic handles the case where op is a MockOp)
  case_op = op.outputs[0].op
  branch_graphs = get_func_graphs(case_op)
  assert branch_graphs
  # Note: op.graph != ops.get_default_graph() when we are computing the gradient
  # of a nested cond.
  for branch_graph in branch_graphs:
    assert branch_graph.outer_graph == case_op.graph

  # Create grad functions that compute the gradient of the branch forward
  # graphs. These functions will capture tensors from the forward pass
  # functions.
  branch_grad_graphs = []
  for branch_graph in branch_graphs:
    branch_grad_graphs.append(
        _create_grad_func(branch_graph, grads,
                          util.unique_grad_fn_name(branch_graph.name)))
  # Replaces output None grads with zeros if at least one branch has non-None
  # grad at that index.
  _create_zeros_for_none_grads(branch_graphs, branch_grad_graphs)

  if any(g.op_needs_rewrite for g in branch_grad_graphs):
    # Modify 'op' to output the intermediates needed by the grad functions. Note
    # that all needed intermediates are wrapped in optionals. Each optional
    # intermediate output will have a value iff its corresponding branch is
    # taken.
    # NOTE(bjp): if there are any active sessions, this modification to `op`
    # may make them unrunnable!

    if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
      # XLA does not yet support optionals, so output intermediates directly and
      # make them match via FakeParams, which can be converted to zeros in XLA.
      # TODO(bjp,jpienaar): can XLA support optionals?
      branches_intermediates = [
          branch_grad_graph.xla_intermediates
          for branch_grad_graph in branch_grad_graphs
      ]
      extra_branch_outputs = _make_intermediates_match_xla(
          branch_graphs, branches_intermediates)
    else:
      branch_intermediates = [
          g.wrapped_intermediates for g in branch_grad_graphs
      ]
      # Make outputs match by adding none optionals.
      extra_branch_outputs = _make_intermediates_match(branch_graphs,
                                                       branch_intermediates)

    for branch_graph, extra_outputs in zip(branch_graphs, extra_branch_outputs):
      branch_graph.outputs.extend(extra_outputs)
    # TODO(bjp): indicate it's an internal bug if this fails.
    _check_same_outputs(_CASE, branch_graphs)

    for branch_graph in branch_graphs:
      branch_graph.name += "_rewritten"

    case_op._set_func_list_attr("branches", [
        util.create_new_tf_function(branch_graph)
        for branch_graph in branch_graphs
    ])
    case_op._set_type_list_attr("Tout", branch_graphs[0].output_types)
    case_op._set_shape_list_attr("output_shapes",
                                 branch_graphs[0].output_shapes)
    case_op._add_outputs([t.dtype for t in extra_branch_outputs[0]],
                         [t.shape for t in extra_branch_outputs[0]])

  # Resolve references to forward graph tensors in grad graphs and ensure
  # they are in-scope, i.e., belong to one of outer graphs of the grad graph.
  branches_grad_inputs = [
      _resolve_grad_inputs(branch_graph, branch_grad_graph) for branch_graph,
      branch_grad_graph in zip(branch_graphs, branch_grad_graphs)
  ]

  # This modifies the graphs in branch_grad_graphs.
  _make_output_composite_tensors_match(_CASE, branch_grad_graphs)

  try:
    lowering = case_op._get_attr_bool("_lower_using_switch_merge")
  except errors_impl.NotFoundError:
    lowering = None

  outputs = _build_case(
      case_op.inputs[0],
      branch_grad_graphs,
      branches_grad_inputs,
      name="gradient",
      lower_using_switch_merge=lowering)

  # The predicate has no gradient.
  return [None] + outputs


def _build_case(branch_index,
                branch_graphs,
                branch_inputs,
                name=None,
                lower_using_switch_merge=None):
  """Creates a `Case` op from `branch_index`, branch graphs and inputs.

  Note that this modifies `branch_graphs` to make the inputs match, and to
  output all intermediate values so they're available for the gradient
  computation.

  `branch_graphs` need not have the same input types, but they must
  have the same output types.

  Args:
    branch_index: integer Tensor
    branch_graphs: List of FuncGraph
    branch_inputs: List of lists of Tensors to be passed to corresponding
      branch_graph as input.
    name: the name for the Case op.
    lower_using_switch_merge: Lower this op using switch merge ops (optional).

  Returns:
    A list of Tensors which are the outputs of the Case op. Does not include
    added intermediate outputs.
  """
  _make_indexed_slices_indices_types_match(_CASE, branch_graphs)
  _check_same_outputs(_CASE, branch_graphs)

  # Add inputs to branch_graphs to make them match. Note that this modifies the
  # graphs in `branch_graphs`.
  case_inputs = _make_inputs_match(branch_graphs, branch_inputs)

  stateful_ops = []
  for bg in branch_graphs:
    stateful_ops.extend([
        op for op in bg.get_operations() if auto_control_deps.op_is_stateful(op)
    ])

  if stateful_ops:
    op_fn = gen_functional_ops.case
  else:
    op_fn = gen_functional_ops.stateless_case

  # Create the Case op.
  with ops.control_dependencies(
      sum((list(bg.control_captures) for bg in branch_graphs), [])):

    def _make_op(inputs):
      case_op, tensors = util.get_op_and_outputs(op_fn(
          branch_index,
          inputs, [t.dtype for t in branch_graphs[0].outputs],
          [util.create_new_tf_function(g) for g in branch_graphs],
          output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
          name=name))
      _copy_handle_data(tensors, *[g.outputs for g in branch_graphs])
      if case_op is not None:
        util.maybe_set_lowering_attr(case_op, lower_using_switch_merge)
        util.maybe_propagate_compile_time_consts_in_xla(case_op)
        _set_read_only_resource_inputs_attr(case_op, branch_graphs)
        # Prevent fetching since the variant outputs can't be fetched directly.
        case_op.graph.prevent_fetching(case_op)

        # Store the branch graphs so they can be reused during the gradient
        # pass.
        for i, bg in enumerate(branch_graphs):
          bg.outer_graph = ops.get_default_graph()
          setattr(case_op, "_branch_graph_{}".format(i), bg)

      return tensors
    tensors = util.run_as_function_for_tape_gradients(_make_op, case_inputs)

  # Return identities for each output of the Case op, rather than the output of
  # the Case op directly. This makes pruning work if the output of switch_case()
  # is fetched: the lowering pass converts the Case outputs into IdentityN
  # outputs, which if fetched will cause all ops in the taken branch to be run
  # (since it takes all merge ops as input). After lowering, each output
  # identity op will end up with only the appropriate merge op as input.
  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
  # correct output structure
  tensors = [array_ops.identity(t) for t in tensors]

  return _pack_sequence_as(branch_graphs[0].structured_outputs, tensors)


def _set_read_only_resource_inputs_attr(op, branch_graphs):
  """Sets the list of resource inputs which are read-only.

  This is used by AutomaticControlDependencies.

  Args:
    op: If or Case Operation.
    branch_graphs: List of branch FuncGraphs.
  """
  # The first entry in `op.inputs` is the predicate which is not passed to
  # branch graphs so len(branch_graph[i].inputs) == len(op.inputs) - 1.
  read_only_indices = set(range(len(op.inputs) - 1))
  for branch_graph in branch_graphs:
    assert len(branch_graph.inputs) == len(op.inputs) - 1, "should never happen"
    if not read_only_indices:
      break
    branch_read_only_indices = acd.get_read_only_resource_input_indices_graph(
        branch_graph)
    read_only_indices = read_only_indices.intersection(branch_read_only_indices)
  # Convert indices in `branch_graphs[i].inputs` to `op.inputs`.
  read_only_indices = [i + 1 for i in read_only_indices]
  ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR,
                        sorted(read_only_indices))