1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Control Flow Operations.
16
17See the [autograph](https://www.tensorflow.org/guide/autograph) guide.
18"""
19# pylint: disable=g-bad-name
20import abc
21import collections
22import functools
23
24from tensorflow.core.framework import attr_value_pb2
25from tensorflow.core.protobuf import control_flow_pb2
26from tensorflow.python.eager import context
27from tensorflow.python.framework import composite_tensor
28from tensorflow.python.framework import constant_op
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import errors
31from tensorflow.python.framework import indexed_slices
32from tensorflow.python.framework import ops
33from tensorflow.python.framework import tensor_shape
34from tensorflow.python.framework import tensor_spec
35from tensorflow.python.framework import tensor_util
36from tensorflow.python.framework import type_spec
37from tensorflow.python.ops import array_ops
38from tensorflow.python.ops import control_flow_util as util
39from tensorflow.python.ops import gen_array_ops
40from tensorflow.python.ops import gen_control_flow_ops
41from tensorflow.python.ops import gen_functional_ops
42from tensorflow.python.ops import gen_logging_ops
43from tensorflow.python.ops import gen_math_ops
44from tensorflow.python.ops import math_ops
45from tensorflow.python.ops import tensor_array_ops
46# go/tf-wildcard-import
47# pylint: disable=wildcard-import,undefined-variable
48from tensorflow.python.ops.gen_control_flow_ops import *
49# pylint: enable=wildcard-import
50from tensorflow.python.platform import tf_logging as logging
51from tensorflow.python.util import compat
52from tensorflow.python.util import deprecation
53from tensorflow.python.util import dispatch
54from tensorflow.python.util import nest
55from tensorflow.python.util import tf_should_use
56from tensorflow.python.util import variable_utils
57from tensorflow.python.util.lazy_loader import LazyLoader
58from tensorflow.python.util.tf_export import tf_export
59
60# This is to avoid a circular dependency:
61# cond_v2 -> gradients_util -> control_flow_ops
62cond_v2 = LazyLoader("cond_v2", globals(),
63                     "tensorflow.python.ops.cond_v2")
64
65# This is to avoid circular dependencies:
66# while_v2 -> control_flow_ops
67# while_v2 -> gradients_util -> control_flow_ops
68while_v2 = LazyLoader("while_v2", globals(),
69                      "tensorflow.python.ops.while_v2")
70
71# def_function also uses cond
72def_function = LazyLoader(
73    "def_function", globals(),
74    "tensorflow.python.eager.def_function")
75
76
77# We override the 'tuple' for a control flow op, so we keep python's
78# existing 'tuple' for later use in this module.
79_basetuple = tuple
80
81
82def _summarize_eager(tensor, summarize=None):
83  """Returns a summarized string representation of eager `tensor`.
84
85  Args:
86    tensor: EagerTensor to summarize
87    summarize: Include this many leading elements of `tensor`.
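
  Example (an illustrative sketch; values are hypothetical):

  ```python
  t = tf.constant([1, 2, 3, 4, 5])
  _summarize_eager(t, summarize=3)   # -> "1, 2, 3, ..."
  _summarize_eager(t, summarize=-1)  # -> "1, 2, 3, 4, 5"
  ```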
88  """
89  # Emulate the behavior of Tensor::SummarizeValue()
90  if summarize is None:
91    summarize = 3
92  elif summarize < 0:
93    summarize = array_ops.size(tensor)
94
95  # reshape((-1,)) is the fastest way to get a flat array view
96  if tensor._rank():  # pylint: disable=protected-access
97    flat = tensor.numpy().reshape((-1,))
98    lst = [str(x) for x in flat[:summarize]]
99    if len(lst) < flat.size:
100      lst.append("...")
101  else:
102    # tensor.numpy() returns a scalar for zero dimensional arrays
103    if gen_math_ops.not_equal(summarize, 0):
104      lst = [str(tensor.numpy())]
105    else:
106      lst = []
107
108  return ", ".join(lst)
109
110
111# pylint: disable=protected-access
112
113
114# Assert and Print are special symbols in python, so we must
115# use an upper-case version of them.
116@tf_export("debugging.Assert", "Assert")
117@dispatch.add_dispatch_support
118@tf_should_use.should_use_result
119def Assert(condition, data, summarize=None, name=None):
120  """Asserts that the given condition is true.
121
122  If `condition` evaluates to false, print the list of tensors in `data`.
123  `summarize` determines how many entries of the tensors to print.
124
125  Args:
126    condition: The condition to evaluate.
127    data: The tensors to print out when condition is false.
128    summarize: Print this many entries of each tensor.
129    name: A name for this operation (optional).
130
131  Returns:
132    assert_op: An `Operation` that, when executed, raises a
133    `tf.errors.InvalidArgumentError` if `condition` is not true.
134    @compatibility(eager)
135    returns None
136    @end_compatibility
137
138  Raises:
139    @compatibility(TF1)
140    When in TF V1 mode (that is, outside `tf.function`) Assert needs a control
141    dependency on the output to ensure the assertion executes:
142
143  ```python
144  # Ensure maximum element of x is smaller or equal to 1
145  assert_op = tf.Assert(tf.less_equal(tf.reduce_max(x), 1.), [x])
146  with tf.control_dependencies([assert_op]):
147    ... code using x ...
148  ```
149
150    @end_compatibility
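
  Example (eager execution; a minimal sketch):

  ```python
  x = tf.constant([0.5, 2.0])
  try:
    tf.debugging.Assert(tf.less_equal(tf.reduce_max(x), 1.), [x])
  except tf.errors.InvalidArgumentError as e:
    print(e.message)  # The message includes a summary of `x`.
  ```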
151  """
152  if context.executing_eagerly():
153    if not condition:
154      xs = ops.convert_n_to_tensor(data)
155      data_str = [_summarize_eager(x, summarize) for x in xs]
156      raise errors.InvalidArgumentError(
157          node_def=None,
158          op=None,
159          message="Expected '%s' to be true. Summarized data: %s" %
160          (condition, "\n".join(data_str)))
161    return
162
163  with ops.name_scope(name, "Assert", [condition, data]) as name:
164    xs = ops.convert_n_to_tensor(data)
165    if all(x.dtype in {dtypes.string, dtypes.int32} for x in xs):
166      # As a simple heuristic, we assume that string and int32 are
167      # on host to avoid the need to use cond. If that is not the case,
168      # we will pay the price copying the tensor to host memory.
169      return gen_logging_ops._assert(condition, data, summarize, name="Assert")
170    else:
171      condition = ops.convert_to_tensor(condition, name="Condition")
172
173      def true_assert():
174        return gen_logging_ops._assert(
175            condition, data, summarize, name="Assert")
176
177      guarded_assert = cond(condition, no_op, true_assert, name="AssertGuard")
178      if context.executing_eagerly():
179        return
180      return guarded_assert.op
181
182
183def _Identity(tensor, name=None):
184  """Return a tensor with the same shape and contents as the input tensor.
185
186  Args:
187    tensor: A Tensor.
188    name: A name for this operation (optional).
189
190  Returns:
191    A Tensor with the same type and value as the input Tensor.
192  """
193  tensor = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True)
194  if isinstance(tensor, ops.Tensor):
195    if tensor.dtype._is_ref_dtype:  # pylint: disable=protected-access
196      return gen_array_ops.ref_identity(tensor, name=name)
197    else:
198      return array_ops.identity(tensor, name=name)
199  elif isinstance(tensor, composite_tensor.CompositeTensor):
200    return nest.map_structure(_Identity, tensor, expand_composites=True)
201  else:
202    raise TypeError("'tensor' must be a Tensor or CompositeTensor. "
203                    f"Received: {type(tensor)}.")
204
205
206def _NextIteration(tensor, name=None):
207  tensor = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True)
208  if isinstance(tensor, ops.Tensor):
209    if tensor.dtype._is_ref_dtype:  # pylint: disable=protected-access
210      return ref_next_iteration(tensor, name=name)
211    else:
212      return next_iteration(tensor, name=name)
213  elif isinstance(tensor, composite_tensor.CompositeTensor):
214    return nest.map_structure(_NextIteration, tensor, expand_composites=True)
215  else:
216    raise TypeError("'tensor' must be a Tensor or CompositeTensor. "
217                    f"Received: {type(tensor)}.")
218
219
220def _Enter(tensor,
221           frame_name,
222           is_constant=False,
223           parallel_iterations=10,
224           use_ref=True,
225           use_input_shape=True,
226           name=None):
227  """Creates or finds a child frame, and makes `tensor` available to it.
228
229  The unique `frame_name` is used by the `Executor` to identify frames. If
230  `is_constant` is true, `tensor` is a constant in the child frame; otherwise
231  it may be changed in the child frame. At most `parallel_iterations`
232  iterations are run in parallel in the child frame.
233
234  Args:
235    tensor: The tensor to be made available to the child frame.
236    frame_name: The name of the child frame.
237    is_constant: If true, the output is constant within the child frame.
238    parallel_iterations: The number of iterations allowed to run in parallel.
239    use_ref: If true, use ref_enter if tensor is of ref type.
240    use_input_shape: If true, set the result's shape based on tensor's shape.
241    name: A name for this operation (optional).
242
243  Returns:
244    The same tensor as `tensor`.
245
246  Raises:
247    ValueError: If any tensor in `tensor` has a less specific shape
248      than its corresponding shape in `shape_invariant`.
249  """
250  tensor = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True)
251  if isinstance(tensor, ops.Tensor):
252    if tensor.dtype._is_ref_dtype and use_ref:  # pylint: disable=protected-access
253      result = gen_control_flow_ops.ref_enter(
254          tensor, frame_name, is_constant, parallel_iterations, name=name)
255    else:
256      result = gen_control_flow_ops.enter(
257          tensor, frame_name, is_constant, parallel_iterations, name=name)
258    if use_input_shape:
259      result.set_shape(tensor.get_shape())
260    return result
261  elif isinstance(tensor, composite_tensor.CompositeTensor):
262
263    def enter_component(t):
264      return _Enter(t, frame_name, is_constant, parallel_iterations, use_ref,
265                    use_input_shape)
266
267    return nest.map_structure(enter_component, tensor, expand_composites=True)
268  else:
269    raise TypeError("'tensor' must be a Tensor or CompositeTensor. "
270                    f"Received: {type(tensor)}.")
271
272
273def exit(tensor, name=None):  # pylint: disable=redefined-builtin
274  """Exits the current frame to its parent frame.
275
276  Exit makes its input `tensor` available to the parent frame.
277
278  Args:
279    tensor: The tensor to be made available to the parent frame.
280    name: A name for this operation (optional).
281
282  Returns:
283    The same tensor as `tensor`.
284  """
285  tensor = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True)
286  if isinstance(tensor, ops.Tensor):
287    if tensor.dtype._is_ref_dtype:  # pylint: disable=protected-access
288      return gen_control_flow_ops.ref_exit(tensor, name)
289    else:
290      return gen_control_flow_ops._exit(tensor, name)
291  elif isinstance(tensor, composite_tensor.CompositeTensor):
292    return nest.map_structure(exit, tensor, expand_composites=True)
293  else:
294    raise TypeError("'tensor' must be a Tensor or CompositeTensor. "
295                    f"Received: {type(tensor)}.")
296
297
298def switch(data, pred, dtype=None, name=None):
299  """Forwards `data` to an output determined by `pred`.
300
301  If `pred` is false, the `data` input is forwarded to the first output.
302  Otherwise, the data goes to the second output.
303
304  This op handles `Tensor`s and `IndexedSlices`.
305
306  Args:
307    data: The tensor to be forwarded to the appropriate output.
308    pred: A scalar that specifies which output port will receive data.
309    dtype: Optional element type for the returned tensor. If missing, the type
310      is inferred from the type of `data`.
311    name: A name for this operation (optional).
312
313  Returns:
314    `(output_false, output_true)`: If `pred` is true, data will be forwarded
315    to `output_true`, otherwise it goes to `output_false`.
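
  Example (graph mode; a minimal sketch using this module's `switch`):

  ```python
  data = tf.constant([1, 2, 3])
  pred = tf.constant(True)
  output_false, output_true = switch(data, pred)
  # Only `output_true` receives `data` at runtime because `pred` is true.
  ```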
316  """
317  with ops.name_scope(name, "Switch", [data, pred]) as name:
318    data = ops.internal_convert_to_tensor_or_composite(
319        data, dtype=dtype, name="data", as_ref=True)
320    pred = ops.convert_to_tensor(pred, name="pred")
321    if isinstance(data, ops.Tensor):
322      return gen_control_flow_ops.switch(data, pred, name=name)
323    else:
324      if not isinstance(data, composite_tensor.CompositeTensor):
325        raise TypeError(
326            "'data' must be a Tensor or CompositeTensor. "
327            f"Received: {type(data)}.")
328      tensors = nest.flatten(data, expand_composites=True)
329      mapped = [gen_control_flow_ops.switch(tensor, pred) for tensor in tensors]
330      mapped_f, mapped_t = zip(*mapped)
331      return (nest.pack_sequence_as(data, mapped_f, expand_composites=True),
332              nest.pack_sequence_as(data, mapped_t, expand_composites=True))
333
334
335def _SwitchRefOrTensor(data, pred, name="Switch"):
336  """Forwards `data` to an output determined by `pred`.
337
338  If `pred` is false, the `data` input is forwarded to the first output.
339  Otherwise, the data goes to the second output.
340
341  This op handles `Tensor`s and `IndexedSlices`.
342
343  Args:
344    data: The tensor to be forwarded to the appropriate output.
345    pred: A scalar that specifies which output port will receive data.
346    name: A name for this operation (optional).
347
348  Returns:
349    `(output_false, output_true)`: If `pred` is true, data will be forwarded to
350    `output_true`, otherwise it goes to `output_false`.
351
352  Raises:
353    TypeError: if data is not a Tensor or IndexedSlices
354  """
355  data = ops.convert_to_tensor_or_composite(data, name="data")
356  # NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below
357  # addresses the following scenario.
358  #
359  # Assume you execute Optimizer.apply_gradients() in a branch of a cond().
360  #
361  # 1. The update op is created inside a `with ops.colocate(var):` block
362  #
363  # 2. Some tensor `data` is captured and a switch is created in a
364  #    `with ops.colocate_with(data):` block.
365  #
366  # with ops.colocate_with(var):
367  #  with ops.colocate_with(data):
368  #    op = ...
369  #
370  # var and data may be pinned to different devices, so we want ops
371  # created within ops.colocate_with(data) to ignore the existing stack.
372  with ops.colocate_with(data, ignore_existing=True):
373    if isinstance(data, ops.Tensor):
374      if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
375        return ref_switch(data, pred, name=name)
376    return switch(data, pred, name=name)
377
378
379def merge(inputs, name=None):
380  """Returns the value of an available element of `inputs`.
381
382  This op tests each of the tensors in `inputs` in turn to determine if any of
383  them is available. If it finds an available tensor, it returns it and its
384  index in `inputs`.
385
386  It is an error if more than one tensor in `inputs` is available. If no tensor
387  in `inputs` is available, the returned tensor and index are not set.
388
389  This op handles both `Tensor`s and `IndexedSlices`. If inputs has a mix of
390  `Tensor`s and `IndexedSlices`, all inputs are converted to IndexedSlices
391  before merging.
392
393  Args:
394    inputs: The input tensors, at most one of which is available.
395    name: A name for this operation (optional).
396
397  Returns:
398    A tuple containing the chosen input tensor and its index in `inputs`.
399
400  Raises:
401    ValueError: If any of the inputs is None, or inputs are IndexedSlices and
402      some but not all have a dense_shape property.
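
  Example (graph mode; a minimal sketch pairing `merge` with `switch`):

  ```python
  pred = tf.constant(True)
  output_false, output_true = switch(tf.constant([1, 2, 3]), pred)
  value, index = merge([output_false, output_true])
  # Only the true branch is available at runtime, so `index` evaluates to 1.
  ```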
403  """
404  if any(inp is None for inp in inputs):
405    raise ValueError("At least one of the merge inputs is None: %s" % inputs)
406  with ops.name_scope(name, "Merge", inputs) as name:
407    inputs = [
408        ops.internal_convert_to_tensor_or_composite(inp, as_ref=True)
409        for inp in inputs
410    ]
411    if all(isinstance(v, ops.Tensor) for v in inputs):
412      if all(v.dtype._is_ref_dtype for v in inputs):  # pylint: disable=protected-access
413        return gen_control_flow_ops.ref_merge(inputs, name)
414      else:
415        return gen_control_flow_ops.merge(inputs, name)
416    else:
417      # If there is a mix of tensors and indexed slices, then convert the
418      # tensors to indexed slices.
419      if all(
420          isinstance(v, (indexed_slices.IndexedSlices, ops.Tensor))
421          for v in inputs):
422        inputs = math_ops._as_indexed_slices_list(inputs, optimize=False)
423
424      for v in inputs:
425        if not isinstance(v, composite_tensor.CompositeTensor):
426          raise TypeError("Type %s not supported" % type(v))
427
428      for v in inputs[1:]:
429        nest.assert_same_structure(inputs[0], v, expand_composites=True)
430
431      flat_inputs = [nest.flatten(v, expand_composites=True) for v in inputs]
432      merged_results = [
433          gen_control_flow_ops.merge(component)
434          for component in zip(*flat_inputs)
435      ]
436      flat_merged = [tensor for (tensor, _) in merged_results]
437      chosen_index = merged_results[0][1]
438      merged_inputs = nest.pack_sequence_as(
439          inputs[0], flat_merged, expand_composites=True)
440      return (merged_inputs, chosen_index)
441
442
443def _convert_tensorarray_to_flow(tensor_or_tensor_array):
444  if isinstance(tensor_or_tensor_array, tensor_array_ops.TensorArray):
445    return tensor_or_tensor_array.flow
446  else:
447    return tensor_or_tensor_array
448
449
450def _convert_flow_to_tensorarray(tensor_or_tensor_array, tensor_or_flow):
451  if isinstance(tensor_or_tensor_array, tensor_array_ops.TensorArray):
452    return tensor_array_ops.build_ta_with_new_flow(tensor_or_tensor_array,
453                                                   tensor_or_flow)
454  else:
455    return tensor_or_flow
456
457
458def _convert_to_tensor_or_composite_or_tensorarray(var):
459  if isinstance(var, tensor_array_ops.TensorArray):
460    return var
461  return ops.convert_to_tensor_or_composite(var)
462
463
464# TODO(xjun): replace this with is_subtype_of after it is landed.
465def _ShapeLessThanOrEqual(shape1, shape2):
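  """Returns True if `shape1` is at least as specific as `shape2`.

  An illustrative sketch: `_ShapeLessThanOrEqual(TensorShape([3, 4]),
  TensorShape([3, None]))` is True, while swapping the arguments yields False,
  because `shape1` must match every known dimension of `shape2` and have the
  same rank (unless `shape2` is fully unknown).
  """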
466  if shape2.dims is None:
467    return True
468  if shape1.ndims != shape2.ndims:
469    return False
470  for dim1, dim2 in zip(shape1.dims, shape2.dims):
471    if dim2.value is not None and dim1.value != dim2.value:
472      return False
473  return True
474
475
476def _shape_invariant_to_type_spec(var, shape=None):
477  """Converts a shape invariant to a TypeSpec.
478
479  If `var` is a TensorArray, it will first be converted to its flow.
480
481  Args:
482    var: The tensor, tensor array or composite tensor whose shape is described
483      by the shape invariant.
484    shape: A `TypeSpec` or `TensorShape`.  If `shape` is already a `TypeSpec`,
485      then it is simply returned as-is.
486
487  Returns:
488    A `TypeSpec` for `var`, consistent with the given shape.
489
490  Raises:
491    TypeError: If `shape` is a TypeSpec and not compatible with `var`.
492    TypeError: If `shape` is not None, a TypeSpec, or a TensorShape.
493    TypeError: If `shape` is a TensorShape, `var` is a CompositeTensor, and
494      `var` doesn't implement the `_shape_invariant_to_type_spec` method.
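
  Example (a hedged sketch):

  ```python
  v = tf.constant([[1., 2.]])  # shape [1, 2]
  spec = _shape_invariant_to_type_spec(v, tf.TensorShape([None, 2]))
  # `spec` is a TensorSpec with shape [None, 2] and dtype float32.
  ```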
495  """
496  var = _convert_tensorarray_to_flow(var)
497  if shape is None:
498    return type_spec.type_spec_from_value(var)
499  elif isinstance(shape, type_spec.TypeSpec):
500    if not shape.is_compatible_with(var):
501      raise TypeError("TypeSpec %r is not compatible with %r" % (shape, var))
502    return shape
503  elif not isinstance(shape, tensor_shape.TensorShape):
504    raise TypeError(
505        "'shape' must be one of TypeSpec, TensorShape or None. "
506        f"Received: {type(shape)}")
507
508  if isinstance(var, ops.Tensor):
509    return tensor_spec.TensorSpec(shape, var.dtype)
510  else:
511    try:
512      return var._shape_invariant_to_type_spec(shape)  # pylint: disable=protected-access
513    except NotImplementedError as e:
514      raise TypeError(
515          f"To describe or constrain a {type(var).__name__}, use a "
516          f"{type(var._type_spec).__name__} instead of a TensorShape.") from e  # pylint: disable=protected-access
517
518
519def _EnforceShapeInvariant(merge_var, next_var):
520  """Check if the shapes of the loops variables are invariants.
521
522  Args:
523    merge_var: The tensor representing the initial values of the loop
524      variables.
525    next_var: The tensor representing the values of the loop variables
526      after one loop iteration.
527
528  Raises:
529    ValueError: If any tensor in `merge_var` has a more specific shape than
530      its corresponding tensor in `next_var`.
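
  Example of the user-facing failure this guards against (a sketch):

  ```python
  i = tf.constant(0)
  x = tf.zeros([1])
  # Raises ValueError: `x` changes shape on every iteration.
  tf.while_loop(lambda i, x: i < 3,
                lambda i, x: (i + 1, tf.concat([x, x], axis=0)),
                [i, x])
  # Declaring a less-specific invariant for `x` fixes it:
  tf.while_loop(lambda i, x: i < 3,
                lambda i, x: (i + 1, tf.concat([x, x], axis=0)),
                [i, x],
                shape_invariants=[i.get_shape(), tf.TensorShape([None])])
  ```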
531  """
532  if isinstance(merge_var, ops.Tensor):
533    m_shape = merge_var.get_shape()
534    n_shape = next_var.get_shape()
535    if not _ShapeLessThanOrEqual(n_shape, m_shape):
536      enter = merge_var.op.inputs[0].op
537      assert util.IsLoopEnter(enter)
538      input_t = enter.inputs[0]
539      raise ValueError(
540          "Input tensor '%s' enters the loop with shape %s, but has shape %s "
541          "after one iteration. To allow the shape to vary across iterations, "
542          "use the `shape_invariants` argument of tf.while_loop to specify a "
543          "less-specific shape." % (input_t.name, input_t.shape, n_shape))
544  else:
545    raise TypeError("'merge_var' must be a Tensor. "
546                    f"Received: {type(merge_var)}.")
547
548
549def _AddNextAndBackEdge(m, v, enforce_shape_invariant=True):
550  """Add NextIteration and back edge from v to m."""
551  if isinstance(m, ops.Tensor):
552    v = ops.convert_to_tensor(v)
553    v = _NextIteration(v)
554    if enforce_shape_invariant:
555      # Make sure the shapes of loop outputs are correct. We do this before
556      # calling _update_input, which will raise a less-helpful error message if
557      # the types don't match.
558      # TODO(skyewm): call this for other cases below (needs testing)
559      _EnforceShapeInvariant(m, v)
560    m.op._update_input(1, v)  # pylint: disable=protected-access
561  elif isinstance(m, composite_tensor.CompositeTensor):
562    # pylint: disable=protected-access
563    def update_component(m_component, v_component):
564      m_component.op._update_input(1, v_component)
565
566    if isinstance(m, indexed_slices.IndexedSlices):
567      v = math_ops._as_indexed_slices(v, optimize=False)
568    # pylint: enable=protected-access
569    v = _NextIteration(v)
570    return nest.map_structure(update_component, m, v, expand_composites=True)
571  else:
572    raise TypeError("'m' must be a Tensor or CompositeTensor. "
573                    f"Received: {type(m)}.")
574  return v
575
576
577class ControlFlowContext(metaclass=abc.ABCMeta):
578  """The base class for control flow context.
579
580  The usage pattern is a sequence of (Enter, Exit) followed by a final
581  ExitResult.
582
583  We maintain the following state for control flow contexts during graph
584  construction:
585   1. graph has _control_flow_context: the current context used to
586      construct new nodes. Changed by ctxt.Enter() and ctxt.Exit()
587   2. op has _control_flow_context: the context to which the op belongs.
588      Set at the time the op is created. Immutable.
589   3. A ControlFlowContext has _outer_context: the context in which this
590      context is created. Set at the time a context is created. Immutable.
591   4. A ControlFlowContext has _context_stack.
592      Pushed and popped by ctxt.Enter() and ctxt.Exit()
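
  A minimal sketch of the pattern, mirroring how `cond` below drives
  `CondContext` (names are illustrative):

  ```python
  ctxt = CondContext(pred, pivot, branch=1)
  ctxt.Enter()
  # ... build the ops that belong to this context ...
  ctxt.ExitResult(results)
  ctxt.Exit()
  ```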
593  """
594
595  def __init__(self, values_def=None, import_scope=None):
596    self._nested_contexts = []
597    self._outer_context = ops.get_default_graph()._get_control_flow_context()
598    if self._outer_context:
599      self._outer_context._nested_contexts.append(self)  # pylint: disable=protected-access
600    self._context_stack = []
601    if values_def:
602      self._init_values_from_proto(values_def, import_scope=import_scope)
603    else:
604      # The names of tensors that have been already seen in this context.
605      self._values = set()
606      # The keys are the names of tensors referenced by but external to this
607      # context. Each value is the Tensor that should be used by this context to
608      # access the key value (e.g. a switch output guarding a cond input value).
609      self._external_values = {}
610
611  def _init_values_from_proto(self, values_def, import_scope=None):
612    """Initializes values and external_values from `ValuesDef` protocol buffer.
613
614    Args:
615      values_def: `ValuesDef` protocol buffer.
616      import_scope: Optional `string`. Name scope to add.
617    """
618    assert isinstance(values_def, control_flow_pb2.ValuesDef)
619    self._values = set(
620        ops.prepend_name_scope(value, import_scope)
621        for value in values_def.values)
622    g = ops.get_default_graph()
623    self._external_values = {}
624    for k, v in values_def.external_values.items():
625      k = ops.prepend_name_scope(k, import_scope)
626      self._external_values[k] = g.as_graph_element(
627          ops.prepend_name_scope(v, import_scope))
628    op_names = set([
629        op.split(":")[0]
630        for op in self._values - set(self._external_values.keys())
631    ])
632    for op in op_names:
633      # pylint: disable=protected-access
634      g.as_graph_element(op)._set_control_flow_context(self)
635      # pylint: enable=protected-access
636
637  @property
638  def name(self):
639    return self._name
640
641  @property
642  def outer_context(self):
643    """Return the context containing this context."""
644    return self._outer_context
645
646  @property
647  def grad_state(self):
648    raise NotImplementedError("Abstract method")
649
650  @property
651  def back_prop(self):
652    raise NotImplementedError("Abstract method")
653
654  @abc.abstractmethod
655  def to_control_flow_context_def(self, context_def, export_scope=None):
656    """Serializes this into `context_def`.
657
658    Args:
659      context_def: a `ControlFlowContextDef` protocol buffer.
660      export_scope: Optional `string`. Name scope to remove.
661    """
662    raise NotImplementedError("Abstract method")
663
664  def _to_values_def(self, export_scope=None):
665    """Converts the values to a `ValuesDef` protocol buffer.
666
667    Args:
668      export_scope: Optional `string`. Name scope to remove.
669
670    Returns:
671      A `ValuesDef` protocol buffer.
672    """
673    values_def = control_flow_pb2.ValuesDef()
674    values_def.values.extend(
675        [ops.strip_name_scope(v, export_scope) for v in sorted(self._values)])
676    for k, v in self._external_values.items():
677      k = ops.strip_name_scope(k, export_scope)
678      values_def.external_values[k] = ops.strip_name_scope(v.name, export_scope)
679    return values_def
680
681  def AddName(self, name):
682    self._values.add(name)
683
684  # pylint: disable=protected-access
685  def Enter(self):
686    """Enter this control flow context."""
687    graph = ops.get_default_graph()
688    self._context_stack.append(graph._get_control_flow_context())
689    graph._set_control_flow_context(self)
690
691  def Exit(self):
692    """Exit this control flow context."""
693    graph = ops.get_default_graph()
694    last_context = self._context_stack.pop()
695    graph._set_control_flow_context(last_context)
696
697  def EnterGradientColocation(self, op, gradient_uid):
698    """Start building a gradient colocated with an op."""
699    if self._outer_context:
700      self._outer_context.EnterGradientColocation(op, gradient_uid)
701
702  def ExitGradientColocation(self, op, gradient_uid):
703    """Start building a gradient colocated with an op."""
704    if self._outer_context:
705      self._outer_context.ExitGradientColocation(op, gradient_uid)
706
707  def ExitResult(self, result):
708    """Make a list of tensors available in the outer context."""
709    if self._outer_context:
710      def fn(x):
711        self._outer_context.AddName(x.name)
712        return x
713      nest.map_structure(fn, result, expand_composites=True)
714
715  def GetWhileContext(self):
716    """Return the while context containing this context."""
717    if self._outer_context:
718      return self._outer_context.GetWhileContext()
719    return None
720
721  def _RemoveExternalControlEdges(self, op):
722    """Remove any external control dependency on this op."""
723    while_ctxt = self.GetWhileContext()
724    # A control input of `op` is internal if it is in the same while
725    # loop context as the enclosing while loop context of self.
726    if while_ctxt is None:
727      internal_control_inputs, external_control_inputs = op.control_inputs, []
728    else:
729      internal_control_inputs, external_control_inputs = [], []
730      for x in op.control_inputs:
731        ctxt = util.GetOutputContext(x)
732        if ctxt is not None and ctxt.GetWhileContext() == while_ctxt:
733          internal_control_inputs.append(x)
734        else:
735          external_control_inputs.append(x)
736    if len(internal_control_inputs) != len(op.control_inputs):
737      # TODO(mdan): perhaps there should be a replace_control_inputs()
738      op._remove_all_control_inputs()
739      op._add_control_inputs(internal_control_inputs)
740    return internal_control_inputs, external_control_inputs
741
742  # pylint: enable=protected-access
743
744  def AddInnerOp(self, op):
745    """Notifies a scope about an operator added to an inner scope."""
746    if self._outer_context:
747      self._outer_context.AddInnerOp(op)
748
749  def GetControlPivot(self):
750    """Returns the pivot node for this context, or None."""
751    return None
752
753  def IsWhileContext(self):
754    return False
755
756  def IsCondContext(self):
757    return False
758
759  def IsXLAContext(self):
760    return False
761
762  def __str__(self):
763    return self.name
764
765
766class CondContext(ControlFlowContext):
767  """The context for the conditional construct."""
768
769  def __init__(self,
770               pred=None,
771               pivot=None,
772               branch=None,
773               name="cond_text",
774               context_def=None,
775               import_scope=None):
776    """Creates a `CondContext`.
777
778    Args:
779      pred: The `boolean` tensor for the conditional predicate.
780      pivot: The predicate tensor in this branch.
781      branch: 0 or 1 representing this branch.
782      name: Name of the `CondContext` python object.
783      context_def: Optional `ContextDef` protocol buffer to initialize the
784        `CondContext` object from.
785      import_scope: Optional `string`. Name scope to add. Only used when
786        initializing from protocol buffer.
787    """
788    self._name = ops.get_default_graph().unique_name(name)
789
790    if context_def:
791      self._init_from_proto(context_def, import_scope=import_scope)
792    else:
793      # Initializes the default fields.
794      ControlFlowContext.__init__(self)
795      self._pred = pred  # The boolean tensor for the cond predicate
796      self._pivot = pivot  # The predicate tensor in this branch
797      self._branch = branch  # 0 or 1 representing this branch
798
799      # Values considered to have been already seen in this context. `pred` is
800      # treated as an external value of this context.
801      self._values.add(pred.name)
802      self._external_values[pred.name] = pred
803      self._values.add(pivot.name)
804      pivot.op._set_control_flow_context(self)  # pylint: disable=protected-access
805
806  def _init_from_proto(self, context_def, import_scope=None):
807    """Creates a new `CondContext` from protocol buffer.
808
809    Args:
810      context_def: `CondContextDef` protocol buffer.
811      import_scope: Optional `string`. Name scope to add.
812    """
813    assert isinstance(context_def, control_flow_pb2.CondContextDef)
814    # Create from context_def.
815    g = ops.get_default_graph()
816    self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
817    self._pred = g.as_graph_element(
818        ops.prepend_name_scope(context_def.pred_name, import_scope))
819    self._pivot = g.as_graph_element(
820        ops.prepend_name_scope(context_def.pivot_name, import_scope))
821    self._branch = context_def.branch
822    super(CondContext, self).__init__(
823        values_def=context_def.values_def, import_scope=import_scope)
824
825  @property
826  def pred(self):
827    return self._pred
828
829  @property
830  def pivot(self):
831    return self._pivot
832
833  @property
834  def branch(self):
835    return self._branch
836
837  @property
838  def grad_state(self):
839    if self.GetWhileContext():
840      return self.GetWhileContext().grad_state
841    return None
842
843  @property
844  def back_prop(self):
845    if self.GetWhileContext():
846      return self.GetWhileContext().back_prop
847    return False
848
849  def GetControlPivot(self):
850    return self._pivot
851
852  def to_proto(self, export_scope=None):
853    """Converts a `CondContext` to a `CondContextDef` protocol buffer.
854
855    Args:
856      export_scope: Optional `string`. Name scope to remove.
857
858    Returns:
859      A `CondContextDef` protocol buffer.
860    """
861    if (export_scope is None or self.name.startswith(export_scope)):
862      context_def = control_flow_pb2.CondContextDef()
863      context_def.context_name = ops.strip_name_scope(self.name, export_scope)
864      context_def.pred_name = ops.strip_name_scope(self._pred.name,
865                                                   export_scope)
866      context_def.pivot_name = ops.strip_name_scope(self._pivot.name,
867                                                    export_scope)
868      context_def.branch = self._branch
869      context_def.values_def.MergeFrom(
870          super(CondContext, self)._to_values_def(export_scope))
871      for nested in self._nested_contexts:
872        nested_def = context_def.nested_contexts.add()
873        nested.to_control_flow_context_def(nested_def)
874
875      return context_def
876    else:
877      return None
878
879  @staticmethod
880  def from_proto(context_def, import_scope=None):
881    """Returns a `CondContext` object created from `context_def`."""
882    ret = CondContext(context_def=context_def, import_scope=import_scope)
883
884    ret.Enter()
885    for nested_def in context_def.nested_contexts:
886      from_control_flow_context_def(nested_def, import_scope=import_scope)
887    ret.Exit()
888    return ret
889
890  def to_control_flow_context_def(self, context_def, export_scope=None):
891    context_def.cond_ctxt.CopyFrom(self.to_proto(export_scope=export_scope))
892
893  def AddValue(self, val):
894    """Add `val` to the current context and its outer context recursively."""
895    if val.name in self._values:
896      # Use the real value if it comes from outer context. This is needed in
897      # particular for nested conds.
898      result = self._external_values.get(val.name)
899      result = val if result is None else result
900    else:
901      result = val
902      self._values.add(val.name)
903      if self._outer_context:
904        result = self._outer_context.AddValue(val)
905        self._values.add(result.name)
906        self._external_values[result.name] = result
907      with ops.control_dependencies(None):
908        result = _SwitchRefOrTensor(result, self._pred)[self._branch]
909        if self._outer_context:
910          self._outer_context.AddInnerOp(result.op)
911
912      result.op.graph.prevent_fetching(result.op)
913      # pylint: disable=protected-access
914      result.op._set_control_flow_context(self)
915      # pylint: enable=protected-access
916
917      # Mark Switch output as seen by this context and any outer contexts,
918      # just like what we do for normal op outputs in _AddOpInternal() below.
919      ctxt = self
920      while ctxt is not None:
921        # pylint: disable=protected-access
922        ctxt._values.add(result.name)
923        ctxt = ctxt._outer_context
924        # pylint: enable=protected-access
925
926      self._external_values[val.name] = result
927    return result
928
929  def AddOp(self, op):
930    self._AddOpInternal(op)
931
932  def _AddOpInternal(self, op):
933    """Add `op` to the current context."""
934    if not op.inputs:
935      # If we're in a while loop, remove any control inputs from outside the
936      # loop.
937      self._RemoveExternalControlEdges(op)
938
939      if not any(
940          util.OpInContext(input_op, self) for input_op in op.control_inputs):
941        # pylint: disable=protected-access
942        op._add_control_input(self._pivot.op)
943        # pylint: enable=protected-access
944    else:
945      # Make each input to 'op' available in this CondContext. If an input is
946      # already part of this context there's nothing to do, but if it's
947      # external, AddValue() will handle adding the appropriate Switch node and
948      # other bookkeeping.
949      for index in range(len(op.inputs)):
950        x = op.inputs[index]
951        if op.type == "Merge" and x.op.type == "NextIteration":
952          # Edge case: if we're importing a while loop inside this CondContext,
953          # AddValue() will not correctly handle the NextIteration inputs to
954          # Merge node. The problem is that the NextIteration should also be
955          # part of this context, but if we're importing it won't have been
956          # processed and added to the context yet, so AddValue() will try to
957          # add a Switch which results in an invalid graph. Instead, we use the
958          # NextIteration input as-is here, and it will eventually be added to
959          # the context via AddOp().
960          real_x = x
961        else:
962          real_x = self.AddValue(x)
963        if real_x != x:
964          # pylint: disable=protected-access
965          op._update_input(index, real_x)
966          # pylint: enable=protected-access
967      # Remove any external control dependency on this op.
968      self._RemoveExternalControlEdges(op)
969      # pylint: disable=protected-access
970      if op.graph._is_function(op.type) or op.type == "SymbolicGradient":
971        op._add_control_input(self._pivot.op)
972      # pylint: enable=protected-access
973
974    # Mark op's outputs as seen by this context and any outer contexts.
975    output_names = [x.name for x in op.outputs]
976    ctxt = self
977    while ctxt is not None:
978      # pylint: disable=protected-access
979      ctxt._values.update(output_names)
980      ctxt = ctxt._outer_context
981      # pylint: enable=protected-access
982
983    if self._outer_context or not util.IsLoopExit(op):
984      op.graph.prevent_fetching(op)
985
986    if self._outer_context:
987      self._outer_context.AddInnerOp(op)
988
989  def _ProcessOutputTensor(self, val):
990    """Process an output tensor of a conditional branch."""
991    real_val = val
992    if val.name not in self._values:
993      # Handle the special case of lambda: x
994      self._values.add(val.name)
995      if self._outer_context:
996        real_val = self._outer_context.AddValue(val)
997        self._values.add(real_val.name)
998        self._external_values[real_val.name] = real_val
999      real_val = _SwitchRefOrTensor(real_val, self._pred)[self._branch]
1000      self._external_values[val.name] = real_val
1001    else:
1002      external_val = self._external_values.get(val.name)
1003      if external_val is not None:
1004        real_val = external_val
1005    return real_val
1006
1007  def _BuildCondTensor(self, v):
1008    if isinstance(v, ops.Operation):
1009      # Use pivot as the proxy for this op.
1010      return with_dependencies([v], self._pivot)
1011    else:
1012      v = nest.map_structure(
1013          _convert_tensorarray_to_flow, v, expand_composites=True)
1014      return self._ProcessOutputTensor(ops.convert_to_tensor(v))
1015
1016  def BuildCondBranch(self, fn):
1017    """Add the subgraph defined by fn() to the graph."""
1018    pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
1019    original_result = fn()
1020    post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
1021    if len(post_summaries) > len(pre_summaries):
1022      new_summaries = post_summaries[len(pre_summaries):]
1023      summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
1024      summary_ref[:] = pre_summaries
1025      with ops.control_dependencies(new_summaries):
1026        if original_result is None:
1027          return no_op(), None
1028        elif not isinstance(original_result, ops.Operation):
1029          original_result = variable_utils.convert_variables_to_tensors(
1030              original_result)
1031          original_result = nest.map_structure(
1032              array_ops.identity, original_result, expand_composites=True)
1033    if original_result is None:
1034      return None, None
1035
1036    original_result = variable_utils.convert_variables_to_tensors(
1037        original_result)
1038    result = nest.map_structure(
1039        self._BuildCondTensor, original_result, expand_composites=True)
1040    if not isinstance(result, (list, _basetuple)):
1041      result = [result]
1042    return original_result, result
1043
1044  def IsCondContext(self):
1045    return True
1046
1047
1048def _UnpackIfSingleton(res):
1049  if isinstance(res, (list, _basetuple)) and len(res) == 1:
1050    return res[0]
1051  else:
1052    return res
1053
1054
1055def _eager_cond_implementation(pred, true_fn, false_fn, strict, name):
1056  """Special cases for `cond` when executing eagerly."""
1057  pred = ops.convert_to_tensor(pred)
1058  pred_constant_value = tensor_util.constant_value(pred)
1059  if pred_constant_value is None:
1060    # Eager tensors from a parallel device may not have a constant
1061    # value. Running the cond op itself would work, but we don't have logic to
1062    # build cond ops without wrapping in a function first.
1063    if (not isinstance(true_fn, def_function.Function)
1064        or not isinstance(false_fn, def_function.Function)):
1065      raise TypeError("When running tf.cond on a parallel device, 'true_fn' "
1066                      "and 'false_fn' must be decorated with `tf.function`.")
1067    @def_function.function
1068    def _parallel_device_cond_wrapper():
1069      return cond_v2.cond_v2(pred, true_fn, false_fn, name)
1070    functions_run_eagerly = def_function.functions_run_eagerly()
1071    if functions_run_eagerly:
1072      # We need to use tf.function to deal with variable creation inside the
1073      # cond, and skipping it because of run_functions_eagerly would just
1074      # crash immediately.
1075      logging.warning(
1076          "It looks like tf.function behavior was disabled, perhaps using "
1077          "tf.config.run_functions_eagerly. Parallelized tf.cond requires "
1078          "tf.function to work. This primitive will override the disable.")
1079    def_function.run_functions_eagerly(False)
1080    try:
1081      return _parallel_device_cond_wrapper()
1082    finally:
1083      if functions_run_eagerly is not None:
1084        def_function.run_functions_eagerly(functions_run_eagerly)
1085  else:
1086    # For conditions which are eager tensors with a constant value (most of
1087    # them), we only call the relevant branch function and execute it eagerly.
1088    with ops.name_scope(name, "cond", [pred]):
1089      if pred_constant_value:
1090        result = true_fn()
1091      else:
1092        result = false_fn()
1093      if not strict:
1094        result = _UnpackIfSingleton(result)
1095      return result
1096
1097
1098# pylint: disable=redefined-outer-name
1099# pylint: disable=g-doc-args
1100@tf_export(v1=["cond"])
1101@dispatch.add_dispatch_support
1102@deprecation.deprecated_args(
1103    None, "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.",
1104    "fn1", "fn2")
1105def cond(pred,
1106         true_fn=None,
1107         false_fn=None,
1108         strict=False,
1109         name=None,
1110         fn1=None,
1111         fn2=None):
1112  """Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
1113
1114  `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
1115  `false_fn` must have the same non-zero number and type of outputs.
1116
1117  **WARNING**: Any Tensors or Operations created outside of `true_fn` and
1118  `false_fn` will be executed regardless of which branch is selected at runtime.
1119
1120  Although this behavior is consistent with the dataflow model of TensorFlow,
1121  it has frequently surprised users who expected lazier semantics.
1122  Consider the following simple program:
1123
1124  ```python
1125  z = tf.multiply(a, b)
1126  result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
1127  ```
1128
1129  If `x < y`, the `tf.add` operation will be executed and `tf.square`
1130  operation will not be executed. Since `z` is needed for at least one
1131  branch of the `cond`, the `tf.multiply` operation is always executed,
1132  unconditionally.
1133
1134  Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
1135  call to `cond`, and not at all during `Session.run()`). `cond`
1136  stitches together the graph fragments created during the `true_fn` and
1137  `false_fn` calls with some additional graph nodes to ensure that the right
1138  branch gets executed depending on the value of `pred`.
1139
1140  `tf.cond` supports nested structures as implemented in
1141  `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
1142  same (possibly nested) value structure of lists, tuples, and/or named tuples.
1143  Singleton lists and tuples form the only exceptions to this: when returned by
1144  `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
1145  This behavior is disabled by passing `strict=True`.
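
  For example, a sketch of the singleton-unpacking behavior:

  ```python
  r = tf.cond(pred, lambda: [tf.constant(1)], lambda: [tf.constant(2)])
  # r is a Tensor: the singleton lists are unpacked.
  r = tf.cond(pred, lambda: [tf.constant(1)], lambda: [tf.constant(2)],
              strict=True)
  # r is a single-element list.
  ```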
1146
1147  Args:
1148    pred: A scalar determining whether to return the result of `true_fn` or
1149      `false_fn`.
1150    true_fn: The callable to be performed if pred is true.
1151    false_fn: The callable to be performed if pred is false.
1152    strict: A boolean that enables/disables 'strict' mode; see above.
1153    name: Optional name prefix for the returned tensors.
1154
1155  Returns:
1156    Tensors returned by the call to either `true_fn` or `false_fn`. If the
1157    callables return a singleton list, the element is extracted from the list.
1158
1159  Raises:
1160    TypeError: if `true_fn` or `false_fn` is not callable.
1161    ValueError: if `true_fn` and `false_fn` do not return the same number of
1162      tensors, or return tensors of different types.
1163
1164  Example:
1165
1166  ```python
1167  x = tf.constant(2)
1168  y = tf.constant(5)
1169  def f1(): return tf.multiply(x, 17)
1170  def f2(): return tf.add(y, 23)
1171  r = tf.cond(tf.less(x, y), f1, f2)
1172  # r is set to f1().
1173  # Operations in f2 (e.g., tf.add) are not executed.
1174  ```
1175
1176  """
1177  # We needed to make true_fn/false_fn keyword arguments for
1178  # backwards-compatibility. This check exists so that we can convert back to
1179  # having them be positional arguments.
1180  # TODO(josh11b): Make `true_fn` and `false_fn` positional arguments after
1181  # `fn1` and `fn2` are deleted.
1182  if fn1 is not None:
1183    if true_fn is not None:
1184      raise TypeError(
1185          "cond(): 'true_fn' and 'fn1' may not be set simultaneously.")
1186    true_fn = fn1
1187  elif true_fn is None:
1188    raise TypeError("cond(): 'true_fn' argument required")
1189  if fn2 is not None:
1190    if false_fn is not None:
1191      raise TypeError(
1192          "cond(): 'false_fn' and 'fn2' may not be set simultaneously.")
1193    false_fn = fn2
1194  elif false_fn is None:
1195    raise TypeError("cond(): 'false_fn' argument required")
1196
1197  if not callable(true_fn):
1198    raise TypeError("'true_fn' must be callable.")
1199  if not callable(false_fn):
1200    raise TypeError("'false_fn' must be callable.")
1201
1202  if context.executing_eagerly():
1203    return _eager_cond_implementation(pred, true_fn, false_fn, strict, name)
1204
1205  # Always enable control flow v2 if building a function, regardless of toggle.
1206  if util.EnableControlFlowV2(ops.get_default_graph()):
1207    return cond_v2.cond_v2(pred, true_fn, false_fn, name)
1208
1209  with ops.name_scope(name, "cond", [pred]):
1210    # Add the Switch to the graph.
1211    if isinstance(pred, bool):
1212      raise TypeError("'pred' must not be a Python bool.")
1213    p_2, p_1 = switch(pred, pred)
1214    pivot_1 = array_ops.identity(p_1, name="switch_t")
1215    pivot_2 = array_ops.identity(p_2, name="switch_f")
1216    pred = array_ops.identity(pred, name="pred_id")
1217    # Disable the fetching of tensors that are only on one branch of cond.
1218    for tensor in [p_1, p_2, pivot_1, pivot_2, pred]:
1219      tensor.op.graph.prevent_fetching(tensor.op)
1220
1221    # Build the graph for the true branch in a new context.
1222    context_t = CondContext(pred, pivot_1, branch=1)
1223    try:
1224      context_t.Enter()
1225      orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
1226      if orig_res_t is None:
1227        raise ValueError("'true_fn' must have a return value.")
1228      context_t.ExitResult(res_t)
1229    finally:
1230      context_t.Exit()
1231
1232    # Build the graph for the false branch in a new context.
1233    context_f = CondContext(pred, pivot_2, branch=0)
1234    try:
1235      context_f.Enter()
1236      orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
1237      if orig_res_f is None:
1238        raise ValueError("'false_fn' must have a return value.")
1239      context_f.ExitResult(res_f)
1240    finally:
1241      context_f.Exit()
1242
1243    if not strict:
1244      orig_res_t = _UnpackIfSingleton(orig_res_t)
1245      orig_res_f = _UnpackIfSingleton(orig_res_f)
1246
1247    # Check that the return values of the two branches have the same structure.
1248    try:
1249      nest.assert_same_structure(orig_res_t, orig_res_f, expand_composites=True)
1250    except (TypeError, ValueError):
1251      nest.map_structure(_cast_indexed_slice_indices, orig_res_t, orig_res_f)
1252      nest.map_structure(_cast_indexed_slice_indices, res_t, res_f)
1253      try:
1254        nest.assert_same_structure(orig_res_t, orig_res_f,
1255                                   expand_composites=True)
1256      except TypeError as e:
1257        raise TypeError(
1258            f"Incompatible return types of 'true_fn' and 'false_fn': {e}")
1259      except ValueError as e:
1260        raise ValueError(
1261            f"Incompatible return values of 'true_fn' and 'false_fn': {e}")
1262
1263    # Add the final merge to the graph.
1264    if not res_t:
1265      raise ValueError(
1266          "'true_fn' and 'false_fn' must return at least one result.")
1267
1268    res_t_flat = nest.flatten(res_t, expand_composites=True)
1269    res_f_flat = nest.flatten(res_f, expand_composites=True)
1270
1271    for (x, y) in zip(res_t_flat, res_f_flat):
1272      assert isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor)
1273      if x.dtype.base_dtype != y.dtype.base_dtype:
1274        raise ValueError(
1275            "Outputs of 'true_fn' and 'false_fn' must have the same type(s). "
1276            f"Received {x.dtype.name} from 'true_fn' "
1277            f"and {y.dtype.name} from 'false_fn'.")
1278
1279    merges = [merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)]
1280    merges = nest.map_structure(
1281        _convert_flow_to_tensorarray,
1282        nest.flatten(orig_res_t, expand_composites=True),
1283        merges)
1284
1285    # Only add non-nested conds to the collection. Any nested control flow will
1286    # be encapsulated in the root context.
1287    assert context_t.outer_context == context_f.outer_context
1288    if context_t.outer_context is None:
1289      ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t)
1290      ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f)
1291
1292    merges = nest.pack_sequence_as(
1293        structure=orig_res_t, flat_sequence=merges, expand_composites=True)
1294
1295    # Singleton lists and tuples are automatically unpacked if strict == False.
1296    if not strict:
1297      merges = _UnpackIfSingleton(merges)
1298    return merges
1299
1300
1301def _cast_indexed_slice_indices(a, b):
1302  """Cast IndexedSlice.indices from int32 to int64 where necessary.
1303
1304  If `a` and `b` are both IndexedSlices, and their indices have different
1305  dtypes, then cast both their dtypes to `int64` (modifies `a` and `b`
1306  in-place).  Otherwise, does nothing.
1307
1308  Args:
1309    a: A value, which may be an IndexedSlices.
1310    b: A value, which may be an IndexedSlices.
1311  """
1312  if (isinstance(a, indexed_slices.IndexedSlices) and
1313      isinstance(b, indexed_slices.IndexedSlices) and
1314      a.indices.dtype != b.indices.dtype):
1315    # pylint: disable=protected-access
1316    a._indices = math_ops.cast(a.indices, dtypes.int64)
1317    b._indices = math_ops.cast(b.indices, dtypes.int64)
1318
1319
1320# pylint: enable=g-doc-args
1321# pylint: enable=redefined-outer-name
1322
1323
1324@tf_export("cond", v1=[])
1325@dispatch.add_dispatch_support
1326def cond_for_tf_v2(pred, true_fn=None, false_fn=None, name=None):
1327  """Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
1328
1329  `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
1330  `false_fn` must have the same non-zero number and type of outputs.
1331
1332  **WARNING**: Any Tensors or Operations created outside of `true_fn` and
1333  `false_fn` will be executed regardless of which branch is selected at runtime.
1334
1335  Although this behavior is consistent with the dataflow model of TensorFlow,
1336  it has frequently surprised users who expected lazier semantics.
1337  Consider the following simple program:
1338
1339  ```python
1340  z = tf.multiply(a, b)
1341  result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
1342  ```
1343
1344  If `x < y`, the `tf.add` operation will be executed and `tf.square`
1345  operation will not be executed. Since `z` is needed for at least one
1346  branch of the `cond`, the `tf.multiply` operation is always executed,
1347  unconditionally.
1348
1349  Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
1350  call to `cond`, and not at all during `Session.run()`). `cond`
1351  stitches together the graph fragments created during the `true_fn` and
1352  `false_fn` calls with some additional graph nodes to ensure that the right
1353  branch gets executed depending on the value of `pred`.
1354
1355  `tf.cond` supports nested structures as implemented in
1356  `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
1357  same (possibly nested) value structure of lists, tuples, and/or named tuples.
1358  Singleton lists and tuples form the only exceptions to this: when returned by
1359  `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
1360
1361  Note: It is illegal to "directly" use tensors created inside a cond branch
1362  outside it, e.g. by storing a reference to a branch tensor in the python
1363  state. If you need to use a tensor created in a branch function you should
1364  return it as an output of the branch function and use the output from
1365  `tf.cond` instead.
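
  For example, a sketch of the anti-pattern and its fix:

  ```python
  leaked = []
  def true_fn():
    t = tf.constant(1)
    leaked.append(t)  # Don't do this: `t` must not be used outside the branch.
    return t
  r = tf.cond(pred, true_fn, lambda: tf.constant(2))
  # Use `r`, the output of `tf.cond`, rather than `leaked[0]`.
  ```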
1366
1367  Args:
1368    pred: A scalar determining whether to return the result of `true_fn` or
1369      `false_fn`.
1370    true_fn: The callable to be performed if pred is true.
1371    false_fn: The callable to be performed if pred is false.
1372    name: Optional name prefix for the returned tensors.
1373
1374  Returns:
1375    Tensors returned by the call to either `true_fn` or `false_fn`. If the
1376    callables return a singleton list, the element is extracted from the list.
1377
1378  Raises:
1379    TypeError: if `true_fn` or `false_fn` is not callable.
1380    ValueError: if `true_fn` and `false_fn` do not return the same number of
1381      tensors, or return tensors of different types.
1382
1383  Example:
1384
1385  ```python
1386  x = tf.constant(2)
1387  y = tf.constant(5)
1388  def f1(): return tf.multiply(x, 17)
1389  def f2(): return tf.add(y, 23)
1390  r = tf.cond(tf.less(x, y), f1, f2)
1391  # r is set to f1().
1392  # Operations in f2 (e.g., tf.add) are not executed.
1393  ```
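
  Example with nested structures (an illustrative sketch; both branches must
  return the same structure and dtypes):

  ```python
  x = tf.constant(2.0)
  y = tf.constant(5.0)
  def f1(): return (x + y, [x * y, x - y])
  def f2(): return (x - y, [x / y, x + y])
  total, pair = tf.cond(x < y, f1, f2)
  # `total` is a scalar Tensor and `pair` is a list of two Tensors, matching
  # the structure returned by both branches.
  ```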
1394
1395  """
1396  return cond(pred, true_fn=true_fn, false_fn=false_fn, strict=True, name=name)
1397
1398
1399def _resource_safe_shape(t):
1400  """Returns the shape of t or the variable it points to."""
1401  if t.dtype == dtypes.resource:
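    # Walk back through the producing ops to the op that created the resource
    # handle (typically a VarHandleOp), which carries the variable's shape in
    # its "shape" attr.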
1402    while t.op.inputs:
1403      t = t.op.inputs[0]
1404    return tensor_shape.TensorShape(t.op.get_attr("shape"))
1405  return array_ops.shape_internal(t, optimize=False)
1406
1407
1408# TODO(yuanbyu): Consider having a unified notion of context for
1409# not only conditionals and loops but also control dependency and
1410# subgraphs.
1411class WhileContext(ControlFlowContext):
1412  """The context for the loop construct."""
1413
1414  def __init__(self,
1415               maximum_iterations=None,
1416               parallel_iterations=10,
1417               back_prop=True,
1418               swap_memory=False,
1419               name="while_context",
1420               grad_state=None,
1421               context_def=None,
1422               import_scope=None):
1423    """"Creates a `WhileContext`.
1424
1425    Args:
1426      maximum_iterations: Optional upper bound on number of loop iterations.
1427      parallel_iterations: The number of iterations allowed to run in parallel.
1428      back_prop: Whether backprop is enabled for this while loop.
1429      swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
1430      name: Optional name prefix for the returned tensors.
1431      grad_state: The gradient loop state.
1432      context_def: Optional `WhileContextDef` protocol buffer to initialize the
1433        `WhileContext` Python object from.
1434      import_scope: Optional `string`. Name scope to add. Only used when
1435        initializing from a protocol buffer.
1436    """
1437    if context_def:
1438      self._init_from_proto(context_def, import_scope=import_scope)
1439    else:
1440      ControlFlowContext.__init__(self)
1441      self._init_from_args(maximum_iterations, parallel_iterations, back_prop,
1442                           swap_memory, name)
1443    # The gradient loop state.
1444    self._grad_state = grad_state
1445
1446  def _init_from_args(self, maximum_iterations, parallel_iterations, back_prop,
1447                      swap_memory, name):
1448    """Creates a new `WhileContext` from arguments.
1449
1450    Args:
1451      maximum_iterations: Optional upper bound on number of loop iterations.
1452      parallel_iterations: The number of iterations allowed to run in parallel.
1453      back_prop: Whether backprop is enabled for this while loop.
1454      swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
1455      name: Optional name prefix for the returned tensors.
1456
1457    Raises:
1458      ValueError: If `parallel_iterations` has an invalid value.
1459    """
1460    if not isinstance(parallel_iterations, int) or (parallel_iterations <= 0):
1461      raise ValueError("'parallel_iterations' must be a positive integer: "
1462                       "%s" % parallel_iterations)
1463    self._name = ops.get_default_graph().unique_name(name)
1464    self._maximum_iterations = maximum_iterations
1465    self._parallel_iterations = parallel_iterations
1466    self._back_prop = back_prop
1467    self._swap_memory = swap_memory
1468    # We use this node to control constants created by the pred lambda.
1469    self._pivot_for_pred = None
1470    # We use this node to control constants created by the body lambda.
1471    self._pivot_for_body = None
1472    # The boolean tensor for loop termination condition. Used in code
1473    # generation for gradient computation.
1474    self._pivot = None
1475    # The list of exit tensors for loop variables.
1476    self._loop_exits = []
1477    # The list of enter tensors for loop variables.
1478    self._loop_enters = []
1479    self._graph = ops.get_default_graph()
1480
1481  def _init_from_proto(self, context_def, import_scope=None):
1482    """Creates a new `WhileContext` from protocol buffer.
1483
1484    Args:
1485      context_def: `WhileContextDef` protocol buffer.
1486      import_scope: Optional `string`. Name scope to add.
1487    """
1488    assert isinstance(context_def, control_flow_pb2.WhileContextDef)
1489    # Create from context_def.
1490    g = ops.get_default_graph()
1491    self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
1492    if context_def.maximum_iterations_name:
1493      self._maximum_iterations = g.as_graph_element(
1494          ops.prepend_name_scope(context_def.maximum_iterations_name,
1495                                 import_scope))
1496    else:
1497      self._maximum_iterations = None
1498    self._parallel_iterations = context_def.parallel_iterations
1499    self._back_prop = context_def.back_prop
1500    self._swap_memory = context_def.swap_memory
1501    self._pivot_for_pred = g.as_graph_element(
1502        ops.prepend_name_scope(context_def.pivot_for_pred_name, import_scope))
1503    # We use this node to control constants created by the body lambda.
1504    self._pivot_for_body = g.as_graph_element(
1505        ops.prepend_name_scope(context_def.pivot_for_body_name, import_scope))
1506    # The boolean tensor for loop termination condition. Used in code
1507    # generation for gradient computation.
1508    self._pivot = g.as_graph_element(
1509        ops.prepend_name_scope(context_def.pivot_name, import_scope))
1510    # The list of exit tensors for loop variables.
1511    self._loop_exits = [
1512        g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope))
1513        for exit_name in context_def.loop_exit_names
1514    ]
1515    # The list of enter tensors for loop variables.
1516    self._loop_enters = [
1517        g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope))
1518        for enter_name in context_def.loop_enter_names
1519    ]
1520    super(WhileContext, self).__init__(
1521        values_def=context_def.values_def, import_scope=import_scope)
1522
1523    # import_scope causes self.name to be different from the original serialized
1524    # context's name. Rewrite "frame_name" attrs with the new name.
1525    if import_scope:
1526      for tensor_name in self._values:
1527        op = g.as_graph_element(tensor_name).op
1528        if util.IsLoopEnter(op):
1529          # pylint: disable=protected-access
1530          op._set_attr("frame_name",
1531                       attr_value_pb2.AttrValue(s=compat.as_bytes(self.name)))
1532          # pylint: enable=protected-access
1533    self._graph = ops.get_default_graph()
1534
1535  @property
1536  def maximum_iterations(self):
1537    """The maximum number of iterations that will be executed."""
1538    return self._maximum_iterations
1539
1540  @property
1541  def parallel_iterations(self):
1542    """The number of iterations allowed to run in parallel."""
1543    return self._parallel_iterations
1544
1545  @property
1546  def back_prop(self):
1547    """True iff backprop is enabled for this while loop."""
1548    return self._back_prop
1549
1550  @property
1551  def swap_memory(self):
1552    """True iff GPU-CPU memory swap is enabled for this while loop."""
1553    return self._swap_memory
1554
1555  @property
1556  def pivot(self):
1557    """The boolean tensor representing the loop termination condition."""
1558    return self._pivot
1559
1560  @property
1561  def loop_enters(self):
1562    """The list of enter tensors for loop variables."""
1563    return self._loop_enters
1564
1565  @property
1566  def loop_exits(self):
1567    """The list of exit tensors for loop variables."""
1568    return self._loop_exits
1569
1570  @property
1571  def grad_state(self):
1572    """The gradient loop state."""
1573    return self._grad_state
1574
1575  def to_proto(self, export_scope=None):
1576    """Converts a `WhileContext` to a `WhileContextDef` protocol buffer.
1577
1578    Args:
1579      export_scope: Optional `string`. Name scope to remove.
1580
1581    Returns:
1582      A `WhileContextDef` protocol buffer.
1583    """
1584    if (export_scope is None or self.name.startswith(export_scope)):
1585      context_def = control_flow_pb2.WhileContextDef()
1586      context_def.context_name = ops.strip_name_scope(self.name, export_scope)
1587      context_def.parallel_iterations = self._parallel_iterations
1588      if self._maximum_iterations is not None:
1589        context_def.maximum_iterations_name = ops.strip_name_scope(
1590            self._maximum_iterations.name, export_scope)
1591      context_def.back_prop = self._back_prop
1592      context_def.swap_memory = self._swap_memory
1593      context_def.pivot_for_pred_name = ops.strip_name_scope(
1594          self._pivot_for_pred.name, export_scope)
1595      context_def.pivot_for_body_name = ops.strip_name_scope(
1596          self._pivot_for_body.name, export_scope)
1597      context_def.pivot_name = ops.strip_name_scope(self._pivot.name,
1598                                                    export_scope)
1599      context_def.loop_exit_names.extend([
1600          ops.strip_name_scope(l.name, export_scope) for l in self._loop_exits
1601      ])
1602      context_def.loop_enter_names.extend([
1603          ops.strip_name_scope(l.name, export_scope) for l in self._loop_enters
1604      ])
1605      context_def.values_def.MergeFrom(
1606          super(WhileContext, self)._to_values_def(export_scope=export_scope))
1607      for nested in self._nested_contexts:
1608        nested_def = context_def.nested_contexts.add()
1609        nested.to_control_flow_context_def(nested_def)
1610
1611      return context_def
1612    else:
1613      return None
1614
1615  def to_control_flow_context_def(self, context_def, export_scope=None):
1616    context_def.while_ctxt.CopyFrom(self.to_proto(export_scope=export_scope))
1617
1618  @staticmethod
1619  def from_proto(context_def, import_scope=None):
1620    """Returns a `WhileContext` object created from `context_def`.
1621
1622    Args:
1623      context_def: A `WhileContextDef` protocol buffer.
1624      import_scope: Optional `string`. Name scope to add.
1625
1626    Returns:
1627      A `WhileContext` Python object.
1628    """
1629    ret = WhileContext(context_def=context_def, import_scope=import_scope)
1630    ret.Enter()
1631    for nested_def in context_def.nested_contexts:
1632      from_control_flow_context_def(nested_def, import_scope=import_scope)
1633    ret.Exit()
1634    return ret
1635
1636  def GetWhileContext(self):
1637    return self
1638
1639  def GetControlPivot(self):
1640    if self._pivot_for_body is not None:
1641      return self._pivot_for_body
1642    return self._pivot_for_pred
1643
1644  def AddValue(self, val):
1645    """Add `val` to the current context and its outer context recursively."""
1646    result = val
1647    new_value = val.name not in self._values
1648    # Don't treat ops in this context as new values. Usually all known values
1649    # are in self._values, except when we're importing a while loop inside this
1650    # WhileContext. Since there's a cycle in this case, `val` may be part of the
1651    # imported while loop but not yet processed by this context and added to
1652    # self._values in _AddOpInternal. We only want to process external input
1653    # tensors to the while loop here.
1654    new_value &= val.op._control_flow_context is not self  # pylint: disable=protected-access
1655    if new_value:
1656      self._values.add(val.name)
1657
1658      # If we are in a grad context and val is from its forward context,
1659      # use GetRealValue(), which adds the logic to save the history of
1660      # val in forward.
1661      grad_ctxt = ops.get_default_graph()._get_control_flow_context()
1662      if grad_ctxt:
1663        grad_ctxt = grad_ctxt.GetWhileContext()
1664        if grad_ctxt.grad_state:
1665          forward_ctxt = util.GetWhileContext(val.op)
1666          if util.IsLoopExit(val.op):
1667            forward_ctxt = forward_ctxt.outer_context
1668            if forward_ctxt:
1669              forward_ctxt = forward_ctxt.GetWhileContext()
1670          if forward_ctxt == grad_ctxt.grad_state.forward_context:
1671            real_val = grad_ctxt.grad_state.GetRealValue(val)
1672            self._external_values[val.name] = real_val
1673            return real_val
1674
1675      if self._outer_context is not None:
1676        result = self._outer_context.AddValue(val)
1677      # Create an Enter to make `result` known to this loop context.
1678      with ops.control_dependencies(None):
1679        enter = _Enter(
1680            result,
1681            self._name,
1682            is_constant=True,
1683            parallel_iterations=self._parallel_iterations)
1684        enter.graph.prevent_feeding(enter)
1685        if self._outer_context:
1686          self._outer_context.AddInnerOp(enter.op)
1687      # Fix the control inputs and control flow context of these enter ops.
1688      self._FixControlInputsAndContext([enter])
1689
1690      # Add `enter` in this context.
1691      self._values.add(enter.name)
1692      self._external_values[val.name] = enter
1693      result = enter
1694    else:
1695      actual_val = self._external_values.get(val.name)
1696      if actual_val is not None:
1697        result = actual_val
1698    return result
1699
1700  def AddOp(self, op):
1701    """Add `op` to the current context."""
1702    # For a reduction op, if op is in a grad context and its input is from
1703    # its forward context, moving op to the forward context means we would
1704    # store the tensor after the reduction as opposed to the tensor before
1705    # reduction, and therefore could significantly reduce memory consumption.
1706    # For now, we do this only for a few ops.
1707    #
1708    # If in XLA context, do not move constant ops to forward pass as pushing to
1709    # and popping from a stack removes the constant property of an op and breaks
1710    # XLA compilation, which requires certain inputs to be constant for certain
1711    # ops.
1712    if not util.IsInXLAContext(op) and op.type in {"Shape", "Size", "Rank"}:
1713      grad_ctxt = ops.get_default_graph()._get_control_flow_context()
1714      if grad_ctxt:
1715        grad_ctxt = grad_ctxt.GetWhileContext()
1716        if grad_ctxt.grad_state:
1717          op_input_forward_ctxt = util.GetWhileContext(op.inputs[0].op)
1718          if op_input_forward_ctxt == grad_ctxt.grad_state.forward_context:
1719            op_input_ctxt = op.inputs[0].op._get_control_flow_context()
1720            op._set_control_flow_context(op_input_ctxt)
1721            op_input_ctxt._AddOpInternal(op)
1722            return
1723    self._AddOpInternal(op)
1724
1725  def _AddOpInternal(self, op):
1726    """Add `op` to the current context.
1727
1728    We move any external control dependencies of the op to the loop pivot, to
1729    ensure they get executed.
1730    """
1731    # This is needed to prevent frame mismatch errors where there are Const
1732    # nodes inside tf.function in v1 while_loop and inlining is turned on.
1733    if op.type in ["PartitionedCall", "StatefulPartitionedCall"]:
1734      op._add_control_input(self.GetControlPivot().op)  # pylint: disable=protected-access
1735    if not op.inputs:
1736      # Remove any external control dependency on this op
1737      control_inputs, external_inputs = self._RemoveExternalControlEdges(op)
1738      # Add a control edge from the control pivot to this op.
1739      if not control_inputs:
1740        # pylint: disable=protected-access
1741        op._add_control_input(self.GetControlPivot().op)
1742        # pylint: enable=protected-access
1743      for x in op.outputs:
1744        self._values.add(x.name)
1745    else:
1746      for index in range(len(op.inputs)):
1747        x = op.inputs[index]
1748        real_x = self.AddValue(x)
1749        if real_x != x:
1750          op._update_input(index, real_x)  # pylint: disable=protected-access
1751      # Remove any external control dependency on this op.
1752      _, external_inputs = self._RemoveExternalControlEdges(op)
1753      # Add a control dependency to prevent loop invariants from
1754      # enabling ops that should not be executed.
1755      self._MaybeAddControlDependency(op)
1756      for x in op.outputs:
1757        self._values.add(x.name)
1758    if external_inputs:
1759      # Use an identity to pull control inputs as data inputs. Note that we
1760      # ignore ops which don't have outputs. TODO(apassos): fix that
1761      with ops.control_dependencies(None):
1762        self.Enter()
1763        external_inputs = [
1764            array_ops.identity(x.outputs[0]).op
1765            for x in external_inputs
1766            if x.outputs
1767        ]
1768        self.Exit()
1769      op._add_control_inputs(external_inputs)  # pylint: disable=protected-access
1770    if self._outer_context or not util.IsLoopExit(op):
1771      op.graph.prevent_fetching(op)
1772      for x in op.outputs:
1773        op.graph.prevent_feeding(x)
1774
1775    if self._outer_context:
1776      self._outer_context.AddInnerOp(op)
1777
1778  def _MaybeAddControlDependency(self, op):
1779    """Add a control input to the op if it only depends on loop invariants."""
1780
1781    def _IsOpFree(op):
1782      """Determines if `op` needs a control dependency."""
1783      if op.control_inputs:
1784        return False
1785      # pylint: disable=protected-access
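      # Function call ops and SymbolicGradient ops are treated as free without
      # inspecting their inputs, so they get a control edge from the pivot.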
1786      if op.graph._is_function(op.type) or op.type == "SymbolicGradient":
1787        return True
1788      # pylint: enable=protected-access
1789      for x in op.inputs:
1790        if not util.IsLoopConstantEnter(x.op):
1791          return False
1792      return True
1793
1794    if _IsOpFree(op):
1795      # pylint: disable=protected-access
1796      op._add_control_input(self.GetControlPivot().op)
1797      # pylint: enable=protected-access
1798
1799  def AddForwardLoopCounter(self, outer_grad_state):
1800    """Adds a loop that counts the number of iterations.
1801
1802    This is added to the forward loop at the time when we start to
1803    create the loop for backprop gradient computation. Called in
1804    the outer context of this forward context.
1805
1806    The pseudocode is:
1807      `n = 0; while (_pivot) { n++; }`
1808
1809    Note that a control dependency is added to `n` to ensure the correct
1810    execution order of stack push ops.
1811
1812    Args:
1813      outer_grad_state: The outer grad state. None if not nested.
1814
1815    Returns:
1816      The number of iterations taken by the forward loop and the loop index.
1817    """
1818    n = constant_op.constant(0, name="f_count")
1819    if outer_grad_state is not None:
1820      # Force the stack pushes of i-th execution of an inner loop to be ordered
1821      # before the pushes of (i+1)-th execution of the same inner loop.
1822      outer_add_op = outer_grad_state.forward_index.op.inputs[0].op
1823      n.op._add_control_input(outer_add_op)  # pylint: disable=protected-access
1824
1825    self.Enter()
1826    self.AddName(n.name)
1827    enter_n = _Enter(
1828        n,
1829        self._name,
1830        is_constant=False,
1831        parallel_iterations=self._parallel_iterations,
1832        name="f_count")
1833    self.loop_enters.append(enter_n)
1834
1835    merge_n = merge([enter_n, enter_n])[0]
1836    switch_n = switch(merge_n, self._pivot)
1837
1838    index = math_ops.add(switch_n[1], 1)
1839    next_n = _NextIteration(index)
1840    merge_n.op._update_input(1, next_n)
1841
1842    total_iterations = exit(switch_n[0], name="f_count")
1843    self.loop_exits.append(total_iterations)
1844    self.ExitResult([total_iterations])
1845    self.Exit()
1846    return total_iterations, next_n
1847
1848  def AddBackpropLoopCounter(self, count, outer_grad_state):
1849    """Add the backprop loop that controls the iterations.
1850
1851    This is added to the backprop loop. It is used to control the loop
1852    termination of the backprop loop. Called in the outer context of
1853    this grad context.
1854
1855    The pseudocode is:
1856      `n = count; while (n >= 1) { n--; }`
1857
1858    Note that a control dependency is added to `final_zero` to ensure the
1859    correct execution order of stack pop ops.
1860
1861    Args:
1862      count: The number of iterations for backprop.
1863      outer_grad_state: The outer grad state. None if not nested.
1864
1865    Returns:
1866      The loop index.
1867    """
1868    in_separate_functions = count.graph is not ops.get_default_graph()
1869    if in_separate_functions:
1870      # Brings the count into this graph
1871      count = array_ops.identity(count)
1872    else:
1873      # TODO(apassos) XLA expects this constant to be created outside the loop,
1874      # so doing that for now.
1875      one = constant_op.constant(1, name="b_count")
1876
1877    self.Enter()
1878    self.AddName(count.name)
1879    enter_count = _Enter(
1880        count,
1881        self._name,
1882        is_constant=False,
1883        parallel_iterations=self._parallel_iterations,
1884        name="b_count")
1885    self.loop_enters.append(enter_count)
1886
1887    merge_count = merge([enter_count, enter_count])[0]
1888    self._pivot_for_pred = merge_count
1889
1890    if in_separate_functions:
1891      one = constant_op.constant(1, name="b_count")
1892    pred = math_ops.greater_equal(merge_count, one)
1893    self._pivot = loop_cond(pred, name="b_count")
1894    switch_count = switch(merge_count, self._pivot)
1895
1896    index = math_ops.subtract(switch_count[1], one)
1897    self._pivot_for_body = index
1898    next_count = _NextIteration(index)
1899    merge_count.op._update_input(1, next_count)
1900
1901    final_zero = exit(switch_count[0], name="b_count")
1902    self.loop_exits.append(final_zero)
1903    if outer_grad_state is not None:
1904      # Force the stack pops of i-th execution of an inner loop to be ordered
1905      # before the pops of (i+1)-th execution of the same inner loop.
1906      # pylint: disable=protected-access
1907      outer_grad_state.grad_sync._add_control_input(final_zero.op)
1908      # pylint: enable=protected-access
1909
1910    self.ExitResult([final_zero])
1911    self.Exit()
1912    return next_count
1913
1914  def AddBackpropAccumulator(self, op, grad):
1915    """Add an accumulation loop for every loop invariant.
1916
1917    This is added to the backprop loop. It is used to accumulate partial
1918    gradients within each loop iteration. Called when in the gradient while
1919    context.
1920
1921    The pseudocode is:
1922      ```
1923      acc = 0.0;
1924      while (_pivot) {
1925        acc += grad;
1926      }
1927      ```
1928
1929    Args:
1930      op: The Enter op for a loop invariant.
1931      grad: The partial gradient of an iteration for a loop invariant.
1932
1933    Returns:
1934      The gradient for a loop invariant.
1935    """
1936    self.Exit()
1937    # Create a zeros tensor with the right shape for acc. If we don't
1938    # know the full shape statically, we will have to get the shape
1939    # dynamically from the forward inference. Getting the shape right
1940    # for the zeros is only needed for the base case when the loop exits
1941    # without running any iterations.
1942    shape = grad.get_shape()
1943    if shape.is_fully_defined():
1944      if self.outer_context:
1945        self.outer_context.Enter()
1946      acc = constant_op.constant(0, grad.dtype, shape=shape, name="b_acc")
1947      if self.outer_context:
1948        self.outer_context.Exit()
1949    else:
1950      value = op.inputs[0]
1951      if (isinstance(self.outer_context, WhileContext) and
1952          self.outer_context.grad_state is not None):
1953        # We are in a nested while loop.
1954        forward_ctxt = self.grad_state.forward_context
1955        forward_ctxt.outer_context.Enter()
1956        zeros_shape = array_ops.shape_internal(value, optimize=False)
1957        forward_ctxt.outer_context.Exit()
1958        outer_grad_state = self.grad_state.outer_grad_state
1959        history_zeros_shape = outer_grad_state.AddForwardAccumulator(
1960            zeros_shape)
1961        self.outer_context.Enter()
1962        real_shape = outer_grad_state.AddBackpropAccumulatedValue(
1963            history_zeros_shape, zeros_shape)
1964        acc = array_ops.zeros(real_shape, grad.dtype)
1965        self.outer_context.Exit()
1966      else:
1967        if self.outer_context:
1968          self.outer_context.Enter()
1969        zeros_shape = array_ops.shape_internal(value, optimize=False)
1970        acc = array_ops.zeros(zeros_shape, grad.dtype)
1971        if self.outer_context:
1972          self.outer_context.Exit()
1973
1974    self.Enter()
1975    self.AddName(acc.name)
1976    enter_acc = _Enter(
1977        acc,
1978        self._name,
1979        is_constant=False,
1980        parallel_iterations=self._parallel_iterations,
1981        name="b_acc")
1982    self.loop_enters.append(enter_acc)
1983
1984    merge_acc = merge([enter_acc, enter_acc], name="b_acc")[0]
1985    switch_acc_false, switch_acc_true = switch(merge_acc, self._pivot)
1986
1987    add_acc = math_ops.add(switch_acc_true, grad)
1988    next_acc = _NextIteration(add_acc)
1989    merge_acc.op._update_input(1, next_acc)  # pylint: disable=protected-access
1990
1991    result_acc = exit(switch_acc_false, name="b_acc")
1992    self.loop_exits.append(result_acc)
1993    self.ExitResult([result_acc])
1994    return result_acc
1995
1996  def AddBackpropIndexedSlicesAccumulator(self, op, grad):
1997    """This is used for accumulating gradients that are IndexedSlices.
1998
1999    This is essentially the equivalent of AddBackpropAccumulator but optimized
2000    for things like updating embeddings from within a while loop.
2001
2002    Args:
2003      op: The Enter op for a loop invariant.
2004      grad: The partial gradients represented as an IndexedSlices.
2005
2006    Returns:
2007      The accumulated IndexedSlices gradient of the loop invariant.
2008    """
2009    values = grad.values
2010    indices = grad.indices
2011    dense_shape = grad.dense_shape
2012
2013    self.Exit()
2014    if self.outer_context:
2015      self.outer_context.Enter()
2016    if values.get_shape().is_fully_defined():
2017      values_shape = tensor_shape.TensorShape([tensor_shape.Dimension(1)] +
2018                                              values.get_shape().dims[1:])
2019      if self.outer_context:
2020        self.outer_context.Enter()
2021      values_acc = constant_op.constant(
2022          0, values.dtype, shape=values_shape, name="b_acc")
2023      if self.outer_context:
2024        self.outer_context.Exit()
2025    else:
2026      values_shape = _resource_safe_shape(op.inputs[0])[1:]
2027      values_shape = array_ops.concat([[1], values_shape], 0)
2028      values_acc = array_ops.zeros(values_shape, dtype=values.dtype)
2029    indices_acc = constant_op.constant([0], indices.dtype)
2030    shape_acc = None
2031    if dense_shape is not None:
2032      if dense_shape.get_shape().is_fully_defined():
2033        if self.outer_context:
2034          self.outer_context.Enter()
2035        shape_acc = constant_op.constant(
2036            0, dense_shape.dtype, shape=dense_shape.get_shape())
2037        if self.outer_context:
2038          self.outer_context.Exit()
2039      else:
2040        shape_acc = array_ops.zeros_like(
2041            array_ops.shape_internal(
2042                op.inputs[0], optimize=False, out_type=dense_shape.dtype),
2043            optimize=False)
2044
2045    if self.outer_context:
2046      self.outer_context.Exit()
2047
2048    self.Enter()
2049    self.AddName(values_acc.name)
2050    self.AddName(indices_acc.name)
2051    init_acc = [indices_acc, values_acc]
2052    if shape_acc is not None:
2053      self.AddName(shape_acc.name)
2054      init_acc.append(shape_acc)
2055
2056    # Set use_input_shape=False since the accumulator tensors will grow in
2057    # size. If use_input_shape=True, the _update_input call below will result in
2058    # incompatible shapes.
2059    enter_acc = [
2060        _Enter(
2061            x,
2062            self._name,
2063            is_constant=False,
2064            parallel_iterations=self._parallel_iterations,
2065            use_input_shape=False,
2066            name="b_acc") for x in init_acc
2067    ]
2068    # Manually set appropriate partial shapes.
2069    enter_acc[0].set_shape([None])
2070    if values_acc.shape.dims is not None:
2071      enter_acc[1].set_shape([None] + values_acc.shape.as_list()[1:])
2072    self.loop_enters.extend(enter_acc)
2073
2074    merge_acc = [merge([x, x], name="b_acc")[0] for x in enter_acc]
2075    switch_acc = [switch(x, self._pivot) for x in merge_acc]
2076
2077    # The actual accumulation.
2078    acc_indexed_slices = [
2079        array_ops.concat([xa[1], xv], 0)
2080        for xa, xv in zip(switch_acc[:2], [indices, values])
2081    ]
2082    if shape_acc is not None:
2083      # For the shape we just keep the maximum
2084      acc_indexed_slices.append(math_ops.maximum(dense_shape, switch_acc[2][1]))
2085
2086    next_acc = [_NextIteration(x) for x in acc_indexed_slices]
2087    for xm, xn in zip(merge_acc, next_acc):
2088      xm.op._update_input(1, xn)  # pylint: disable=protected-access
2089
2090    exit_acc = [exit(x[0], name="b_acc") for x in switch_acc]
2091    self.loop_exits.extend(exit_acc)
2092
2093    self.ExitResult(exit_acc)
2094    return indexed_slices.IndexedSlices(
2095        indices=exit_acc[0],
2096        values=exit_acc[1],
2097        dense_shape=exit_acc[2] if shape_acc is not None else None)
2098
2099  def _InitializeValues(self, values):
2100    """Makes the values known to this context."""
2101    self._values = set()
2102    for x in values:
2103      if isinstance(x, ops.Tensor):
2104        self._values.add(x.name)
2105      else:
2106        raise TypeError("'values' must be a list of Tensors. "
2107                        f"Received: {type(x)}.")
2108
2109  def _BuildLoop(self, pred, body, flat_orig_loop_vars, flat_loop_vars,
2110                 loop_vars_signature):
2111    """Core: Add the loop termination condition and body to the graph."""
2112    flat_shape_invariants = nest.map_structure(
2113        lambda spec: spec.shape,
2114        nest.flatten(loop_vars_signature, expand_composites=True))
2115
2116    # Let the context know the loop variables so that they are added to the
2117    # outer contexts properly.
2118    self._InitializeValues(flat_loop_vars)
2119    if self._outer_context:
2120      real_vars = [self._outer_context.AddValue(x) for x in flat_loop_vars]
2121    else:
2122      real_vars = flat_loop_vars
2123
2124    enter_vars = []
2125    with ops.control_dependencies(None):
2126      for real_var, shape_invariant in zip(real_vars, flat_shape_invariants):
2127        enter_var = _Enter(
2128            real_var,
2129            self._name,
2130            is_constant=False,
2131            parallel_iterations=self._parallel_iterations,
2132            use_input_shape=False)
2133
2134        if _ShapeLessThanOrEqual(real_var.get_shape(), shape_invariant):
2135          enter_var.set_shape(shape_invariant)
2136        else:
2137          raise ValueError(
2138              f"The shape invariant specified for {real_var.name} is not "
2139              "compatible with the initial shape of the loop variable. It "
2140              f"enters the loop with shape {real_var.get_shape()}, but the "
2141              f"specified shape invariant is {shape_invariant}.")
2142
2143        enter_var.graph.prevent_feeding(enter_var)
2144        if self._outer_context:
2145          self._outer_context.AddInnerOp(enter_var.op)
2146        enter_vars.append(enter_var)
2147
2148    # Finds the closest enclosing non-None control pivot.
2149    outer_context = self._outer_context
2150    control_pivot = None
2151    while outer_context is not None and control_pivot is None:
2152      control_pivot = outer_context.GetControlPivot()
2153      # pylint: disable=protected-access
2154      outer_context = outer_context._outer_context
2155      # pylint: enable=protected-access
2156
2157    if control_pivot is not None:
2158      for var in enter_vars:
2159        if util.IsLoopConstantEnter(var.op.inputs[0].op):
2160          # pylint: disable=protected-access
2161          var.op._add_control_input(control_pivot.op)
2162          # pylint: enable=protected-access
2163
2164    # Fix the control inputs and control flow context of these enter ops.
2165    self._FixControlInputsAndContext(enter_vars)
2166    self._InitializeValues(enter_vars)
2167    self._loop_enters = enter_vars
2168
2169    merge_vars = [merge([x, x])[0] for x in enter_vars]
2170    self._pivot_for_pred = merge_vars[0]
2171
2172    merge_vars_with_tensorarrays = nest.map_structure(
2173        _convert_flow_to_tensorarray, flat_orig_loop_vars, merge_vars)
2174    # Build the graph for pred.
2175    packed_vars = nest.pack_sequence_as(
2176        structure=loop_vars_signature,
2177        flat_sequence=merge_vars_with_tensorarrays,
2178        expand_composites=True)
2179    c = ops.convert_to_tensor(pred(*packed_vars))
2180    self._pivot = loop_cond(c, name="LoopCond")
2181    switch_vars = [_SwitchRefOrTensor(x, self._pivot) for x in merge_vars]
2182
2183    # Build the graph for body.
2184    vars_for_body = [_Identity(x[1]) for x in switch_vars]
2185    self._pivot_for_body = vars_for_body[0]
2186    # Convert TensorArray flow variables inside the context back into
2187    # their associated TensorArrays for calling the body.
2188    vars_for_body_with_tensorarrays = nest.map_structure(
2189        _convert_flow_to_tensorarray, flat_orig_loop_vars, vars_for_body)
2190    packed_vars_for_body = nest.pack_sequence_as(
2191        structure=loop_vars_signature,
2192        flat_sequence=vars_for_body_with_tensorarrays,
2193        expand_composites=True)
2194    pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
2195    body_result = body(*packed_vars_for_body)
2196    post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
2197    if not nest.is_nested(body_result):
2198      body_result = [body_result]
2199    if len(post_summaries) > len(pre_summaries):
2200      new_summaries = post_summaries[len(pre_summaries):]
2201      summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
2202      summary_ref[:] = pre_summaries
2203      with ops.control_dependencies(new_summaries):
2204
2205        def map_fn(x):
2206          # TODO(apassos) figure out how to trigger with tensor arrays as well
2207          if isinstance(x, tensor_array_ops.TensorArray):
2208            return x
2209          return array_ops.identity(x)
2210
2211        body_result = nest.map_structure(
2212            map_fn, body_result, expand_composites=True)
2213
2214    body_result = variable_utils.convert_variables_to_tensors(body_result)
2215    # Compare the structure types of input and output of body.
2216    # For backwards compatibility, the first layer is forced to a list
2217    # during this comparison, because inputs are typically lists and
2218    # outputs of the body are typically tuples.
2219    nest.assert_same_structure(
2220        list(packed_vars_for_body), list(body_result), expand_composites=True)
2221
2222    # Store body_result to keep track of TensorArrays returned by body
2223    original_body_result = body_result
2224    # Convert TensorArrays returned by body into their flow variables
2225    result = nest.map_structure(
2226        _convert_tensorarray_to_flow,
2227        nest.flatten(body_result, expand_composites=True),
2228        expand_composites=True)
2229    result = ops.convert_n_to_tensor_or_composite(result)
2230
2231    # Add NextIteration and the back edges to complete the loop.
2232    if len(merge_vars) != len(result):
2233      raise ValueError("Number of outputs of 'body' must match the number "
2234                       f"of 'loop_vars'. Got {len(result)} outputs from "
2235                       f"'body' and {len(merge_vars)} 'loop_vars'.")
2236    next_vars = []
2237    for m, v in zip(merge_vars, result):
2238      next_vars.append(_AddNextAndBackEdge(m, v))
2239
2240    # Add the exit ops.
2241    exit_vars = [exit(x[0]) for x in switch_vars]
2242    self._loop_exits = exit_vars
2243
2244    # Exit the loop.
2245    self.ExitResult(exit_vars)
2246
2247    return original_body_result, exit_vars
2248
2249  def BuildLoop(self, pred, body, loop_vars, shape_invariants,
2250                return_same_structure):
2251    """Add the loop termination condition and body to the graph."""
2252
2253    # Keep flat_orig_loop_vars to identify which are TensorArrays
2254    flat_orig_loop_vars = nest.flatten(loop_vars, expand_composites=True)
2255
2256    loop_vars = nest.map_structure(
2257        _convert_to_tensor_or_composite_or_tensorarray, loop_vars)
2258    # Convert TensorArrays to their flow variables
2259    flat_loop_vars = nest.map_structure(
2260        _convert_tensorarray_to_flow,
2261        nest.flatten(loop_vars, expand_composites=True))
2262
2263    if shape_invariants is not None:
2264      loop_vars_signature = nest.map_structure(
2265          _shape_invariant_to_type_spec, loop_vars, shape_invariants)
2266    else:
2267      loop_vars_signature = nest.map_structure(
2268          _shape_invariant_to_type_spec, loop_vars)
2269
2270    try:
2271      self.Enter()
2272      # _BuildLoop calls _update_input in several places. _mutation_lock()
2273      # ensures a Session.run call cannot occur between creating and mutating
2274      # new ops.
2275      with ops.get_default_graph()._mutation_lock():  # pylint: disable=protected-access
2276        original_body_result, exit_vars = self._BuildLoop(
2277            pred, body, flat_orig_loop_vars, flat_loop_vars,
2278            loop_vars_signature)
2279    finally:
2280      self.Exit()
2281
2282    flat_result = nest.flatten(original_body_result, expand_composites=True)
2283    # Convert TensorArray flow variables outside the context back into
2284    # their associated TensorArrays for returning to caller.
2285    exit_vars_with_tensorarrays = nest.map_structure(
2286        _convert_flow_to_tensorarray, flat_result, exit_vars)
2287
2288    packed_exit_vars = nest.pack_sequence_as(
2289        structure=original_body_result,
2290        flat_sequence=exit_vars_with_tensorarrays,
2291        expand_composites=True)
2292
2293    if return_same_structure:
2294      return packed_exit_vars
2295    else:
2296      return packed_exit_vars[0] if len(exit_vars) == 1 else packed_exit_vars
2297
2298  def _FixControlInputsAndContext(self, enters):
2299    graph = ops.get_default_graph()
2300    # pylint: disable=protected-access
2301    for e in enters:
2302      if isinstance(e, ops.Tensor):
2303        xs = [e]
2304      else:
2305        raise TypeError("'enters' must be a list of Tensors. "
2306                        f"Received: {type(e)}.")
2307      for x in xs:
2308        inp_op = x.op.inputs[0].op
2309        control_inputs = graph._control_dependencies_for_inputs([inp_op])
2310        outer_control_inputs = []
2311        for op in control_inputs:
2312          # We need to keep control inputs that are in any ancestor
2313          # ControlFlowContext, and within outer WhileContext.
2314          keep_as_control_input = True
2315          op_ctxt = util.GetOutputContext(op)
2316          outer_ctxt = self.outer_context
2317          outer_while_context = (None if outer_ctxt is None else
2318                                 outer_ctxt.GetWhileContext())
2319          while outer_ctxt != op_ctxt:
2320            if outer_ctxt is None or outer_ctxt == outer_while_context:
2321              keep_as_control_input = False
2322              break
2323            outer_ctxt = outer_ctxt.outer_context
2324          if keep_as_control_input:
2325            outer_control_inputs.append(op)
2326        x.op._set_control_flow_context(self)
2327        x.op._add_control_inputs(outer_control_inputs)
2328        graph._record_op_seen_by_control_dependencies(x.op)
2329    # pylint: enable=protected-access
2330
2331  def IsWhileContext(self):
2332    return True
2333
2334
2335# @TODO(b/133606651) Replace "shape_invariants" with "loop_vars_signature".
2336# pylint: disable=redefined-outer-name
2337@tf_export("while_loop", v1=[])
2338@deprecation.deprecated_arg_values(
2339    None,
2340    """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
2341Instead of:
2342results = tf.while_loop(c, b, vars, back_prop=False)
2343Use:
2344results = tf.nest.map_structure(tf.stop_gradient, tf.while_loop(c, b, vars))""",
2345    warn_once=True,
2346    back_prop=False)
2347def while_loop_v2(cond,
2348                  body,
2349                  loop_vars,
2350                  shape_invariants=None,
2351                  parallel_iterations=10,
2352                  back_prop=True,
2353                  swap_memory=False,
2354                  maximum_iterations=None,
2355                  name=None):
2356  """Repeat `body` while the condition `cond` is true.
2357
2358  `cond` is a callable returning a boolean scalar tensor. `body` is a callable
2359  returning a (possibly nested) tuple, namedtuple or list of tensors of the same
2360  arity (length and structure) and types as `loop_vars`. `loop_vars` is a
2361  (possibly nested) tuple, namedtuple or list of tensors that is passed to both
2362  `cond` and `body`. `cond` and `body` both take as many arguments as there are
2363  `loop_vars`.
2364
2365  In addition to regular Tensors or IndexedSlices, the body may accept and
2366  return TensorArray objects.  The flows of the TensorArray objects will
2367  be appropriately forwarded between loops and during gradient calculations.
2368
2369  Note that `while_loop` calls `cond` and `body` *exactly once* (inside the
2370  call to `while_loop`, and not at all during `Session.run()`). `while_loop`
2371  stitches together the graph fragments created during the `cond` and `body`
2372  calls with some additional graph nodes to create the graph flow that
2373  repeats `body` until `cond` returns false.
2374
2375  For correctness, `tf.while_loop()` strictly enforces shape invariants for
2376  the loop variables. A shape invariant is a (possibly partial) shape that
2377  is unchanged across the iterations of the loop. An error will be raised
2378  if the shape of a loop variable after an iteration is determined to be more
2379  general than or incompatible with its shape invariant. For example, a shape
2380  of [11, None] is more general than a shape of [11, 17], and [11, 21] is not
2381  compatible with [11, 17]. By default (if the argument `shape_invariants` is
2382  not specified), it is assumed that the initial shape of each tensor in
2383  `loop_vars` is the same in every iteration. The `shape_invariants` argument
2384  allows the caller to specify a less specific shape invariant for each loop
2385  variable, which is needed if the shape varies between iterations. The
2386  `tf.Tensor.set_shape`
2387  function may also be used in the `body` function to indicate that
2388  the output loop variable has a particular shape. The shape invariants for
2389  SparseTensor and IndexedSlices are treated specially as follows:
2390
2391  a) If a loop variable is a SparseTensor, the shape invariant must be
2392  TensorShape([r]) where r is the rank of the dense tensor represented
2393  by the sparse tensor. It means the shapes of the three tensors of the
2394  SparseTensor are ([None], [None, r], [r]). NOTE: The shape invariant here
2395  is the shape of the SparseTensor.dense_shape property. It must be the shape of
2396  a vector.
2397
2398  b) If a loop variable is an IndexedSlices, the shape invariant must be
2399  a shape invariant of the values tensor of the IndexedSlices. It means
2400  the shapes of the three tensors of the IndexedSlices are (shape, [shape[0]],
2401  [shape.ndims]).
2402
2403  `while_loop` implements non-strict semantics, enabling multiple iterations
2404  to run in parallel. The maximum number of parallel iterations can be
2405  controlled by `parallel_iterations`, which gives users some control over
2406  memory consumption and execution order. For correct programs, `while_loop`
2407  should return the same result for any parallel_iterations > 0.
2408
2409  For training, TensorFlow stores the tensors that are produced in the
2410  forward inference and are needed in back propagation. These tensors are a
2411  main source of memory consumption and often cause OOM errors when training
2412  on GPUs. When the flag swap_memory is true, we swap out these tensors from
2413  GPU to CPU. This for example allows us to train RNN models with very long
2414  sequences and large batches.
2415
2416  Args:
2417    cond: A callable that represents the termination condition of the loop.
2418    body: A callable that represents the loop body.
2419    loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array,
2420      `Tensor`, and `TensorArray` objects.
2421    shape_invariants: The shape invariants for the loop variables.
2422    parallel_iterations: The number of iterations allowed to run in parallel. It
2423      must be a positive integer.
2424    back_prop: (optional) Deprecated. False disables support for back
2425      propagation. Prefer using `tf.stop_gradient` instead.
2426    swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
2427    maximum_iterations: Optional maximum number of iterations of the while loop
2428      to run.  If provided, the `cond` output is AND-ed with an additional
2429      condition ensuring the number of iterations executed is no greater than
2430      `maximum_iterations`.
2431    name: Optional name prefix for the returned tensors.
2432
2433  Returns:
2434    The output tensors for the loop variables after the loop. The return value
2435      has the same structure as `loop_vars`.
2436
2437  Raises:
2438    TypeError: if `cond` or `body` is not callable.
2439    ValueError: if `loop_vars` is empty.
2440
2441  Example:
2442
2443  ```python
2444  i = tf.constant(0)
2445  c = lambda i: tf.less(i, 10)
2446  b = lambda i: (tf.add(i, 1), )
2447  r = tf.while_loop(c, b, [i])
2448  ```
2449
2450  Example with nesting and a namedtuple:
2451
2452  ```python
2453  import collections
2454  Pair = collections.namedtuple('Pair', 'j, k')
2455  ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
2456  c = lambda i, p: i < 10
2457  b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k)))
2458  ijk_final = tf.while_loop(c, b, ijk_0)
2459  ```
2460
2461  Example using shape_invariants:
2462
2463  ```python
2464  i0 = tf.constant(0)
2465  m0 = tf.ones([2, 2])
2466  c = lambda i, m: i < 10
2467  b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
2468  tf.while_loop(
2469      c, b, loop_vars=[i0, m0],
2470      shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
2471  ```
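
  Example (an illustrative sketch) using a `TensorArray` to collect
  per-iteration values:

  ```python
  ta = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
  c = lambda i, ta: i < 5
  b = lambda i, ta: (i + 1, ta.write(i, tf.cast(i, tf.float32) ** 2))
  _, ta_final = tf.while_loop(c, b, [tf.constant(0), ta])
  squares = ta_final.stack()  # [0., 1., 4., 9., 16.]
  ```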
2472
2473  Example which demonstrates non-strict semantics: In the following
2474  example, the final value of the counter `i` does not depend on `x`. So
2475  the `while_loop` can increment the counter parallel to updates of `x`.
2476  However, because the loop counter at one loop iteration depends
2477  on the value at the previous iteration, the loop counter itself cannot
2478  be incremented in parallel. Hence if we just want the final value of the
2479  counter (which we print on the line `print(sess.run(i))`), then
2480  `x` will never be incremented, but the counter will be updated on a
2481  single thread. Conversely, if we want the value of the output (which we
2482  print on the line `print(sess.run(out).shape)`), then the counter may be
2483  incremented on its own thread, while `x` can be incremented in
2484  parallel on a separate thread. In the extreme case, it is conceivable
2485  that the thread incrementing the counter runs until completion before
2486  `x` is incremented even a single time. The only thing that can never
2487  happen is for the thread updating `x` to get ahead of the counter
2488  thread, because the thread incrementing `x` depends on the value
2489  of the counter.
2490
2491  ```python
2492  import tensorflow as tf
2493
2494  n = 10000
2495  x = tf.constant(list(range(n)))
2496  c = lambda i, x: i < n
2497  b = lambda i, x: (tf.compat.v1.Print(i + 1, [i]),
2498                    tf.compat.v1.Print(x + 1, [i], "x:"))
2499  i, out = tf.while_loop(c, b, (0, x))
2500  with tf.compat.v1.Session() as sess:
2501      print(sess.run(i))  # prints [0] ... [9999]
2502
2503      # The following line may increment the counter and x in parallel.
2504      # The counter thread may get ahead of the other thread, but not the
2505      # other way around. So you may see things like
2506      # [9996] x:[9987]
2507      # meaning that the counter thread is on iteration 9996,
2508      # while the other thread is on iteration 9987
2509      print(sess.run(out).shape)
2510  ```
2511
2512  """
2513  return while_loop(
2514      cond=cond,
2515      body=body,
2516      loop_vars=loop_vars,
2517      shape_invariants=shape_invariants,
2518      parallel_iterations=parallel_iterations,
2519      back_prop=back_prop,
2520      swap_memory=swap_memory,
2521      name=name,
2522      maximum_iterations=maximum_iterations,
2523      return_same_structure=True)
2524
2525
2526# pylint: disable=redefined-outer-name
2527@tf_export(v1=["while_loop"])
2528def while_loop(cond,
2529               body,
2530               loop_vars,
2531               shape_invariants=None,
2532               parallel_iterations=10,
2533               back_prop=True,
2534               swap_memory=False,
2535               name=None,
2536               maximum_iterations=None,
2537               return_same_structure=False):
2538  """Repeat `body` while the condition `cond` is true.
2539
2540  `cond` is a callable returning a boolean scalar tensor. `body` is a callable
2541  returning a (possibly nested) tuple, namedtuple or list of tensors of the same
2542  arity (length and structure) and types as `loop_vars`. `loop_vars` is a
2543  (possibly nested) tuple, namedtuple or list of tensors that is passed to both
2544  `cond` and `body`. `cond` and `body` both take as many arguments as there are
2545  `loop_vars`.
2546
2547  In addition to regular Tensors or IndexedSlices, the body may accept and
2548  return TensorArray objects.  The flows of the TensorArray objects will
2549  be appropriately forwarded between loops and during gradient calculations.
2550
2551  Note that `while_loop` calls `cond` and `body` *exactly once* (inside the
2552  call to `while_loop`, and not at all during `Session.run()`). `while_loop`
2553  stitches together the graph fragments created during the `cond` and `body`
2554  calls with some additional graph nodes to create the graph flow that
2555  repeats `body` until `cond` returns false.
2556
2557  For correctness, `tf.while_loop()` strictly enforces shape invariants for
2558  the loop variables. A shape invariant is a (possibly partial) shape that
2559  is unchanged across the iterations of the loop. An error will be raised
2560  if the shape of a loop variable after an iteration is determined to be more
2561  general than or incompatible with its shape invariant. For example, a shape
2562  of [11, None] is more general than a shape of [11, 17], and [11, 21] is not
2563  compatible with [11, 17]. By default (if the argument `shape_invariants` is
2564  not specified), it is assumed that the initial shape of each tensor in
2565  `loop_vars` is the same in every iteration. The `shape_invariants` argument
2566  allows the caller to specify a less specific shape invariant for each loop
2567  variable, which is needed if the shape varies between iterations. The
2568  `tf.Tensor.set_shape`
2569  function may also be used in the `body` function to indicate that
2570  the output loop variable has a particular shape. The shape invariants for
2571  SparseTensor and IndexedSlices are treated specially as follows:
2572
2573  a) If a loop variable is a SparseTensor, the shape invariant must be
2574  TensorShape([r]) where r is the rank of the dense tensor represented
2575  by the sparse tensor. It means the shapes of the three tensors of the
2576  SparseTensor are ([None], [None, r], [r]). NOTE: The shape invariant here
2577  is the shape of the SparseTensor.dense_shape property. It must be the shape of
2578  a vector.
2579
2580  b) If a loop variable is an IndexedSlices, the shape invariant must be
2581  a shape invariant of the values tensor of the IndexedSlices. It means
2582  the shapes of the three tensors of the IndexedSlices are (shape, [shape[0]],
2583  [shape.ndims]).
2584
2585  `while_loop` implements non-strict semantics, enabling multiple iterations
2586  to run in parallel. The maximum number of parallel iterations can be
2587  controlled by `parallel_iterations`, which gives users some control over
2588  memory consumption and execution order. For correct programs, `while_loop`
2589  should return the same result for any parallel_iterations > 0.
2590
2591  For training, TensorFlow stores the tensors that are produced in the
2592  forward inference and are needed in back propagation. These tensors are a
2593  main source of memory consumption and often cause OOM errors when training
2594  on GPUs. When the flag swap_memory is true, we swap out these tensors from
2595  GPU to CPU. This for example allows us to train RNN models with very long
2596  sequences and large batches.
2597
2598  Args:
2599    cond: A callable that represents the termination condition of the loop.
2600    body: A callable that represents the loop body.
2601    loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array,
2602      `Tensor`, and `TensorArray` objects.
2603    shape_invariants: The shape invariants for the loop variables.
2604    parallel_iterations: The number of iterations allowed to run in parallel. It
2605      must be a positive integer.
2606    back_prop: Whether backprop is enabled for this while loop.
2607    swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
2608    name: Optional name prefix for the returned tensors.
2609    maximum_iterations: Optional maximum number of iterations of the while loop
2610      to run.  If provided, the `cond` output is AND-ed with an additional
2611      condition ensuring the number of iterations executed is no greater than
2612      `maximum_iterations`.
2613    return_same_structure: If True, output has same structure as `loop_vars`. If
2614      eager execution is enabled, this is ignored (and always treated as True).
2615
2616  Returns:
2617    The output tensors for the loop variables after the loop.
2618     If `return_same_structure` is True, the return value has the same
2619     structure as `loop_vars`.
2620     If `return_same_structure` is False, the return value is a Tensor,
2621     TensorArray or IndexedSlice if the length of `loop_vars` is 1, or a list
2622     otherwise.
2623
2624  Raises:
2625    TypeError: if `cond` or `body` is not callable.
2626    ValueError: if `loop_vars` is empty.
2627
2628  Example:
2629
2630  ```python
2631  i = tf.constant(0)
2632  c = lambda i: tf.less(i, 10)
2633  b = lambda i: tf.add(i, 1)
2634  r = tf.while_loop(c, b, [i])
2635  ```
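
  Example (an illustrative sketch) with `maximum_iterations`:

  ```python
  i = tf.constant(0)
  r = tf.while_loop(lambda i: i < 10, lambda i: i + 1, [i],
                    maximum_iterations=3)
  # The body runs at most 3 times, so `r` evaluates to 3 even though the
  # condition alone would allow 10 iterations.
  ```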
2636
2637  Example with nesting and a namedtuple:
2638
2639  ```python
2640  import collections
2641  Pair = collections.namedtuple('Pair', 'j, k')
2642  ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
2643  c = lambda i, p: i < 10
2644  b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k)))
2645  ijk_final = tf.while_loop(c, b, ijk_0)
2646  ```
2647
2648  Example using shape_invariants:
2649
2650  ```python
2651  i0 = tf.constant(0)
2652  m0 = tf.ones([2, 2])
2653  c = lambda i, m: i < 10
2654  b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
2655  tf.while_loop(
2656      c, b, loop_vars=[i0, m0],
2657      shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
2658  ```
2659
2660  Example which demonstrates non-strict semantics: In the following
2661  example, the final value of the counter `i` does not depend on `x`. So
2662  the `while_loop` can increment the counter parallel to updates of `x`.
2663  However, because the loop counter at one loop iteration depends
2664  on the value at the previous iteration, the loop counter itself cannot
2665  be incremented in parallel. Hence if we just want the final value of the
2666  counter (which we print on the line `print(sess.run(i))`), then
2667  `x` will never be incremented, but the counter will be updated on a
2668  single thread. Conversely, if we want the value of the output (which we
2669  print on the line `print(sess.run(out).shape)`), then the counter may be
2670  incremented on its own thread, while `x` can be incremented in
2671  parallel on a separate thread. In the extreme case, it is conceivable
2672  that the thread incrementing the counter runs until completion before
  `x` is incremented even a single time. The only thing that can never
  happen is the thread updating `x` getting ahead of the counter thread,
  because the thread incrementing `x` depends on the value of the counter.
2677
2678  ```python
2679  import tensorflow as tf
2680
2681  n = 10000
2682  x = tf.constant(list(range(n)))
2683  c = lambda i, x: i < n
2684  b = lambda i, x: (tf.compat.v1.Print(i + 1, [i]), tf.compat.v1.Print(x + 1,
2685  [i], "x:"))
2686  i, out = tf.while_loop(c, b, (0, x))
2687  with tf.compat.v1.Session() as sess:
2688      print(sess.run(i))  # prints [0] ... [9999]
2689
2690      # The following line may increment the counter and x in parallel.
2691      # The counter thread may get ahead of the other thread, but not the
2692      # other way around. So you may see things like
2693      # [9996] x:[9987]
2694      # meaning that the counter thread is on iteration 9996,
2695      # while the other thread is on iteration 9987
2696      print(sess.run(out).shape)
2697  ```
2698
2699  """
2700  if not callable(cond):
2701    raise TypeError("'cond' must be callable.")
2702  if not callable(body):
2703    raise TypeError("'body' must be callable.")
2704  if parallel_iterations < 1:
2705    raise TypeError("'parallel_iterations' must be a positive integer.")
2706
2707  loop_vars = variable_utils.convert_variables_to_tensors(loop_vars)
2708
2709  # Always enable control flow v2 if building a function, regardless of toggle.
2710  executing_eagerly = context.executing_eagerly()
2711  if (util.EnableControlFlowV2(ops.get_default_graph()) and
2712      not executing_eagerly):
2713    return while_v2.while_loop(
2714        cond,
2715        body,
2716        loop_vars,
2717        shape_invariants=shape_invariants,
2718        parallel_iterations=parallel_iterations,
2719        maximum_iterations=maximum_iterations,
2720        name=name,
2721        return_same_structure=return_same_structure,
2722        back_prop=back_prop)
2723
2724  with ops.name_scope(name, "while", loop_vars):
2725    if not loop_vars:
2726      raise ValueError("'loop_vars' must be provided.")
2727    try_to_pack = (len(loop_vars) == 1 and not return_same_structure)
2728    if maximum_iterations is not None:
2729      maximum_iterations = ops.convert_to_tensor(
2730          maximum_iterations, name="maximum_iterations")
2731      if maximum_iterations.shape.ndims != 0:
2732        raise ValueError(
2733            "'maximum_iterations' must be a scalar. "
2734            f"Received shape: {maximum_iterations.shape}")
2735
2736      if executing_eagerly:
2737        counter = 0
2738        maximum_iterations = int(maximum_iterations.numpy())
2739      else:
2740        counter = constant_op.constant(
2741            0, dtype=maximum_iterations.dtype, name="iteration_counter")
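      # Thread an iteration counter through the loop as an extra loop variable
      # and AND `i < maximum_iterations` into the termination condition.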
2742      orig_cond = cond
2743      orig_body = body
2744      if try_to_pack:
2745        loop_vars = (counter, loop_vars[0])
2746        cond = lambda i, lv: (  # pylint: disable=g-long-lambda
2747            math_ops.logical_and(i < maximum_iterations, orig_cond(lv)))
2748        body = lambda i, lv: (i + 1, orig_body(lv))
2749      else:
2750        loop_vars = (counter, loop_vars)
2751        cond = lambda i, lv: (  # pylint: disable=g-long-lambda
2752            math_ops.logical_and(i < maximum_iterations, orig_cond(*lv)))
2753        body = lambda i, lv: (i + 1, orig_body(*lv))
2754      try_to_pack = False
2755
2756    if executing_eagerly:
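      # Eager path: run the loop directly in Python, checking on every
      # iteration that the structure of the loop variables is preserved.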
2757      packed = False  # whether the body result was packed into a 1-item tuple
2758
2759      loop_var_structure = nest.map_structure(type_spec.type_spec_from_value,
2760                                              list(loop_vars))
2761      while cond(*loop_vars):
2762        loop_vars = body(*loop_vars)
2763        if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
2764          packed = True
2765          loop_vars = (loop_vars,)
2766        nest.assert_same_structure(loop_var_structure, list(loop_vars))
2767
2768      def convert(x):
2769        if isinstance(x, tensor_array_ops.TensorArray):
2770          return x
2771        return ops.convert_to_tensor(x)
2772
2773      loop_vars = nest.map_structure(convert, loop_vars, expand_composites=True)
2774      if maximum_iterations is not None:
2775        return loop_vars[1]
2776      else:
2777        return loop_vars[0] if packed else loop_vars
2778
2779    if shape_invariants is not None:
2780      if maximum_iterations is not None:
2781        shape_invariants = (tensor_shape.TensorShape([]), shape_invariants)
2782
2783    loop_context = WhileContext(
2784        maximum_iterations=maximum_iterations,
2785        parallel_iterations=parallel_iterations,
2786        back_prop=back_prop,
2787        swap_memory=swap_memory)
2788    # Only add non-nested loops to the collection. Any nested control flow will
2789    # be encapsulated in the root context.
2790    if loop_context.outer_context is None:
2791      ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context)
2792    result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants,
2793                                    return_same_structure)
2794    if maximum_iterations is not None:
2795      return result[1]
2796    else:
2797      return result
2798
2799
2800# pylint: enable=redefined-outer-name
2801
2802
2803def _AsTensorList(x, p):
2804  """Return x as a list of Tensors or IndexedSlices.
2805
2806  For entries of `x` that are Operations, this returns an Identity of `p`
2807  with a dependency on the operation.
2808
2809  Args:
2810    x: A Tensor/IndexedSlices/Operation or a list or tuple of them.
2811    p: A Tensor to return for entries in `x` that are Operations.
2812
2813  Returns:
2814    A list of Tensors or IndexedSlices.
2815  """
2816  if not isinstance(x, (list, _basetuple)):
2817    x = [x]
2818
2819  l = []
2820  for v in x:
2821    if isinstance(v, ops.Operation):
2822      v = with_dependencies([v], p)
2823    v = ops.convert_to_tensor_or_composite(v)
2824    if isinstance(v, ops.Tensor):
2825      l.append(array_ops.identity(v))
2826    else:
2827      l.append(
2828          indexed_slices.IndexedSlices(
2829              array_ops.identity(v.values), array_ops.identity(v.indices)))
2830  return l
2831
2832
2833def _CheckResults(a, b):
2834  assert len(a) == len(b), (
2835      "Values returned by a() and b() must have the same length.")
2836  for x, y in zip(a, b):
2837    assert x.dtype == y.dtype, (
2838        "Values returned by a() [%s] and b() [%s] must have "
2839        "the same type: %s, %s." % (x.name, y.name, x.dtype.name, y.dtype.name))
2840
2841
2842def with_dependencies(dependencies, output_tensor, name=None):
2843  """Produces the content of `output_tensor` only after `dependencies`.
2844
2845  In some cases, a user may want the output of an operation to be
2846  consumed externally only after some other dependencies have run
  first. This function returns `output_tensor`, but only after all
  operations in `dependencies` have run. Note that the returned value is a
  new tensor carrying a control dependency on `dependencies`; there is no
  guarantee that the op that originally produces `output_tensor` will itself
  run after `dependencies`.
2851
2852  See also `tf.tuple` and `tf.group`.
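
  For example, a minimal graph-mode sketch (`state` is a hypothetical
  `tf.compat.v1.Variable` and `value` an arbitrary tensor):

  ```python
  update = state.assign_add(1)
  # `result` carries the same content as `value`, but only becomes available
  # once `update` has run.
  result = with_dependencies([update], value)
  ```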
2853
2854  Args:
2855    dependencies: Iterable of operations to run before this op finishes.
2856    output_tensor: A `Tensor` or `IndexedSlices` that will be returned.
2857    name: (Optional) A name for this operation.
2858
2859  Returns:
2860    Same as `output_tensor`.
2861
2862  Raises:
2863    TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`.
2864  """
2865  if context.executing_eagerly():
2866    return output_tensor
2867  with ops.name_scope(name, "control_dependency",
2868                      list(dependencies) + [output_tensor]) as name:
2869    with ops.colocate_with(output_tensor):
2870      with ops.control_dependencies(dependencies):
2871        output_tensor = ops.convert_to_tensor_or_composite(output_tensor)
2872        if isinstance(output_tensor, indexed_slices.IndexedSlices):
2873          return indexed_slices.IndexedSlices(
2874              _Identity(output_tensor.values, name=name), output_tensor.indices,
2875              output_tensor.dense_shape)
2876        else:
2877          return _Identity(output_tensor, name=name)
2878
2879
2880def _GroupControlDeps(dev, deps, name=None):
2881  with ops.control_dependencies(deps):
2882    if dev is None:
2883      return no_op(name=name)
2884    else:
2885      with ops.device(dev):
2886        return no_op(name=name)
2887
2888
2889# TODO(touts): Accept "inputs" as a list.
2890@tf_export("group")
2891def group(*inputs, **kwargs):
2892  """Create an op that groups multiple operations.
2893
2894  When this op finishes, all ops in `inputs` have finished. This op has no
2895  output.
2896
2897  Note: *In TensorFlow 2 with eager and/or Autograph, you should not require
2898  this method, as ops execute in the expected order thanks to automatic control
2899  dependencies.* Only use `tf.group` when working with v1
2900  `tf.Graph` code.
2901
2902  When operating in a v1-style graph context, ops are not executed in the same
2903  order as specified in the code; TensorFlow will attempt to execute ops in
2904  parallel or in an order convenient to the result it is computing.  `tf.group`
2905  allows you to request that one or more results finish before execution
2906  continues.
2907
2908  `tf.group` creates a single op (of type `NoOp`), and then adds appropriate
2909  control dependencies.  Thus, `c = tf.group(a, b)` will compute the same graph
2910  as this:
2911
2912      with tf.control_dependencies([a, b]):
2913          c = tf.no_op()
2914
2915  See also `tf.tuple` and
2916  `tf.control_dependencies`.
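
  For example, a minimal graph-mode sketch (`train_op` and `metrics_op` are
  hypothetical ops):

  ```python
  step = tf.group(train_op, metrics_op)
  # Running `step` runs both input ops; `step` itself produces no output.
  ```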
2917
2918  Args:
2919    *inputs: Zero or more tensors to group.
2920    name: A name for this operation (optional).
2921
2922  Returns:
2923    An Operation that executes all its inputs.
2924
2925  Raises:
2926    ValueError: If an unknown keyword argument is provided.
2927  """
2928  if context.executing_eagerly():
2929    return None
2930  name = kwargs.pop("name", None)
2931  if kwargs:
2932    raise ValueError("Unknown keyword arguments: " + ", ".join(kwargs.keys()))
2933  with ops.name_scope(name, "group_deps", inputs) as name:
2934    # Grouping no inputs means do nothing
2935    if not inputs:
2936      return no_op(name=name)
2937
2938    # Sorts *inputs according to their devices.
2939    ops_on_device = {}  # device -> operations specified on the device.
2940    for inp in nest.flatten(inputs, expand_composites=True):
2941      if not hasattr(inp, "device"):
2942        raise TypeError("'inputs' should be zero or more (nested) Tensors. "
2943                        f"Received '{inp}' with type '{type(inp)}'.")
2944      dev = inp.device
2945      if dev in ops_on_device:
2946        ops_on_device[dev].append(inp)
2947      else:
2948        ops_on_device[dev] = [inp]
2949    if len(ops_on_device) == 1:
2950      # 1-level tree. The root node is the returned NoOp node.
2951      (dev, deps), = ops_on_device.items()
2952      return _GroupControlDeps(dev, deps, name=name)
2953
2954    # 2-level tree. The root node is the returned NoOp node.
2955    # deps contains 1 NoOp node for each device.
2956    deps = []
2957
2958    def device_key(dev):
2959      """A sort key that allows None to be compared to strings."""
2960      return "" if dev is None else dev
2961
2962    for dev in sorted(ops_on_device, key=device_key):
2963      deps.append(_GroupControlDeps(dev, ops_on_device[dev]))
2964
2965    with ops.control_dependencies(deps):
2966      return no_op(name=name)
2967
2968
2969@tf_export("tuple", v1=[])
2970@dispatch.add_dispatch_support
2971def tuple_v2(tensors, control_inputs=None, name=None):
2972  """Groups tensors together.
2973
2974  The returned tensors have the same value as the input tensors, but they
2975  are computed only after all the input tensors have been computed.
2976
2977  Note: *In TensorFlow 2 with eager and/or Autograph, you should not require
2978  this method, as ops execute in the expected order thanks to automatic control
2979  dependencies.* Only use `tf.tuple` when working with v1 `tf.Graph` code.
2980
2981  See also `tf.group` and `tf.control_dependencies`.
2982
2983  Args:
2984    tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`.
2985    control_inputs: List of additional ops to finish before returning.
2986    name: (optional) A name to use as a `name_scope` for the operation.
2987
2988  Returns:
2989    Same as `tensors`.
2990
2991  Raises:
2992    ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`.
2993    TypeError: If `control_inputs` is not a list of `Operation` or `Tensor`
2994      objects.
2995
2996  """
2997  return tuple(tensors=tensors, name=name, control_inputs=control_inputs)  # pylint: disable=redefined-builtin
2998
2999
3000@tf_export(v1=["tuple"])
3001@dispatch.add_dispatch_support
3002def tuple(tensors, name=None, control_inputs=None):  # pylint: disable=redefined-builtin
3003  """Group tensors together.
3004
3005  This creates a tuple of tensors with the same values as the `tensors`
3006  argument, except that the value of each tensor is only returned after the
3007  values of all tensors have been computed.
3008
3009  `control_inputs` contains additional ops that have to finish before this op
3010  finishes, but whose outputs are not returned.
3011
3012  This can be used as a "join" mechanism for parallel computations: all the
3013  argument tensors can be computed in parallel, but the values of any tensor
3014  returned by `tuple` are only available after all the parallel computations
3015  are done.
3016
3017  See also `tf.group` and
3018  `tf.control_dependencies`.
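
  For example, a minimal graph-mode sketch (`x`, `w1` and `w2` are hypothetical
  tensors):

  ```python
  a = tf.linalg.matmul(x, w1)
  b = tf.linalg.matmul(x, w2)
  # `a_out` and `b_out` have the same values as `a` and `b`, but neither is
  # available until both `a` and `b` have been computed.
  a_out, b_out = tf.compat.v1.tuple([a, b])
  ```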
3019
3020  Args:
3021    tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`.
3022    name: (optional) A name to use as a `name_scope` for the operation.
3023    control_inputs: List of additional ops to finish before returning.
3024
3025  Returns:
3026    Same as `tensors`.
3027
3028  Raises:
3029    ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`.
3030    TypeError: If `control_inputs` is not a list of `Operation` or `Tensor`
3031      objects.
3032
3033  """
3034  if context.executing_eagerly():
3035    return tensors
3036  with ops.name_scope(name, "tuple", tensors) as name:
3037    tensors = [
3038        t if (isinstance(t, ops.Operation) or tensor_util.is_tf_type(t) or
3039              t is None) else ops.convert_to_tensor(t) for t in tensors
3040    ]
3041    gating_ops = [
3042        t if isinstance(t, ops.Operation) else t.op
3043        for t in tensors
3044        if t is not None
3045    ]
3046    if control_inputs:
3047      for c in control_inputs:
3048        if isinstance(c, ops.Tensor):
3049          c = c.op
3050        elif not isinstance(c, ops.Operation):
3051          raise TypeError(
3052              "'control_inputs' must only contain Operation or Tensor. "
3053              f"Received: {type(c)}")
3054        gating_ops.append(c)
    # Sort the gating ops so that the serialized graph (pbtxt) has a
    # deterministic ordering.
3057    gating_ops = sorted(set(gating_ops), key=lambda op: op._id)  # Uniquify ops.
3058    if not gating_ops:
3059      raise ValueError("'tensors' must have at least one Tensor. "
3060                       f"Received: {tensors}.")
3061    gate = group(*gating_ops)
3062    tpl = []
3063    for t in tensors:
3064      if tensor_util.is_tf_type(t):
3065        tpl.append(with_dependencies([gate], t))
3066      elif isinstance(t, ops.Operation):
3067        with ops.control_dependencies([gate]):
3068          tpl.append(group(t))
3069      else:
3070        tpl.append(None)
3071    return tpl
3072
3073
3074def _assert_at_most_n_true(predicates, n, msg):
3075  """Returns an Assert op that checks that at most n predicates are True.
3076
3077  Args:
3078    predicates: list of bool scalar tensors.
3079    n: maximum number of true predicates allowed.
3080    msg: Error message.
3081  """
3082  preds_c = array_ops.stack(predicates, name="preds_c")
3083  num_true_conditions = math_ops.reduce_sum(
3084      math_ops.cast(preds_c, dtypes.int32), name="num_true_conds")
3085  condition = math_ops.less_equal(num_true_conditions,
3086                                  constant_op.constant(n, name="n_true_conds"))
3087  preds_names = ", ".join(getattr(p, "name", "?") for p in predicates)
3088  error_msg = [
3089      "%s: more than %d conditions (%s) evaluated as True:" %
3090      (msg, n, preds_names), preds_c
3091  ]
3092  return Assert(condition, data=error_msg, summarize=len(predicates))
3093
3094
3095def _case_create_default_action(predicates, actions):
3096  """Creates default action for a list of actions and their predicates.
3097
  It uses the input actions to select an arbitrary one as the default and
  makes sure that the corresponding predicates have valid values.
3100
3101  Args:
3102    predicates: a list of bool scalar tensors
3103    actions: a list of callable objects which return tensors.
3104
3105  Returns:
3106    a callable
3107  """
3108  k = len(predicates) - 1  # could pick any
3109  predicate, action = predicates[k], actions[k]
3110  other_predicates, other_actions = predicates[:k], actions[:k]
3111
3112  def default_action():
    others_msg = ("Implementation error: "
                  "selected default action #%d was called, but some of the "
                  "other predicates are True: " % k)
    default_msg = ("Input error: "
                   "None of the conditions evaluated as True:",
                   array_ops.stack(predicates, name="preds_c"))
3119    with ops.control_dependencies([
3120        _assert_at_most_n_true(other_predicates, n=0, msg=others_msg),
3121        Assert(predicate, data=default_msg)
3122    ]):
3123      return action()
3124
3125  return default_action, other_predicates, other_actions
3126
3127
3128def _case_verify_and_canonicalize_args(pred_fn_pairs, exclusive, name,
3129                                       allow_python_preds):
3130  """Verifies input arguments for the case function.
3131
3132  Args:
3133    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor, and a
3134      callable which returns a list of tensors.
3135    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
3136    name: A name for the case operation.
3137    allow_python_preds: if true, pred_fn_pairs may contain Python bools in
3138      addition to boolean Tensors
3139
3140  Raises:
3141    TypeError: If `pred_fn_pairs` is not a list/dictionary.
3142    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
3143    TypeError: If `fns[i]` is not callable for any i, or `default` is not
3144               callable.
3145
3146  Returns:
3147    a tuple <list of scalar bool tensors, list of callables>.
3148  """
3149  if not isinstance(pred_fn_pairs, (list, _basetuple, dict)):
3150    raise TypeError("'pred_fn_pairs' must be a list, tuple, or dict. "
3151                    f"Received: {type(pred_fn_pairs)}")
3152
3153  if isinstance(pred_fn_pairs, collections.OrderedDict):
3154    pred_fn_pairs = pred_fn_pairs.items()
3155  elif isinstance(pred_fn_pairs, dict):
3156    if context.executing_eagerly():
3157      # No name to sort on in eager mode. Use dictionary traversal order,
3158      # which is nondeterministic in versions of Python < 3.6
3159      if not exclusive:
3160        raise ValueError("Unordered dictionaries are not supported for the "
3161                         "'pred_fn_pairs' argument when `exclusive=False` and "
3162                         "eager mode is enabled.")
3163      pred_fn_pairs = list(pred_fn_pairs.items())
3164    else:
3165      pred_fn_pairs = sorted(
3166          pred_fn_pairs.items(), key=lambda item: item[0].name)
3167      if not exclusive:
        logging.warn(
            "%s: An unordered dictionary of predicate/fn pairs was "
            "provided, but exclusive=False. The order of conditional "
            "tests is deterministic but otherwise unspecified.", name)
3172  for pred_fn_pair in pred_fn_pairs:
3173    if not isinstance(pred_fn_pair, _basetuple) or len(pred_fn_pair) != 2:
3174      raise TypeError("Each entry in 'pred_fn_pairs' must be a 2-tuple. "
3175                      f"Received {pred_fn_pair}.")
3176    pred, fn = pred_fn_pair
3177
3178    if isinstance(pred, ops.Tensor):
3179      if pred.dtype != dtypes.bool:
3180        raise TypeError("pred must be Tensor of type bool: %s" % pred.name)
3181    elif not allow_python_preds:
3182      raise TypeError("pred must be a Tensor, got: %s" % pred)
3183    elif not isinstance(pred, bool):
3184      raise TypeError("pred must be a Tensor or bool, got: %s" % pred)
3185
3186    if not callable(fn):
3187      raise TypeError("fn for pred %s must be callable." % pred.name)
3188
3189  predicates, actions = zip(*pred_fn_pairs)
3190  return predicates, actions
3191
3192
3193def _case_helper(cond_fn,
3194                 pred_fn_pairs,
3195                 default,
3196                 exclusive,
3197                 name,
3198                 allow_python_preds=False,
3199                 **cond_kwargs):
3200  """Implementation of case that allows for different cond functions.
3201
3202  Args:
3203    cond_fn: method that has signature and semantics of `cond` above.
3204    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor, and a
3205      callable which returns a list of tensors.
3206    default: Optional callable that returns a list of tensors.
3207    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
3208    name: A name for this operation (optional).
3209    allow_python_preds: if true, pred_fn_pairs may contain Python bools in
3210      addition to boolean Tensors
3211    **cond_kwargs: keyword arguments that will be passed to `cond_fn`.
3212
3213  Returns:
3214    The tensors returned by the first pair whose predicate evaluated to True, or
3215    those returned by `default` if none does.
3216
3217  Raises:
3218    TypeError: If `pred_fn_pairs` is not a list/dictionary.
3219    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
3220    TypeError: If `fns[i]` is not callable for any i, or `default` is not
3221               callable.
3222  """
3223  predicates, actions = _case_verify_and_canonicalize_args(
3224      pred_fn_pairs, exclusive, name, allow_python_preds)
3225  with ops.name_scope(name, "case", [predicates]):
3226    if default is None:
3227      default, predicates, actions = _case_create_default_action(
3228          predicates, actions)
3229    fn = default
3230    # To eval conditions in direct order we create nested conditions in reverse:
3231    #   cond_fn(c[0], true_fn=.., false_fn=cond_fn(c[1], ...))
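    #   e.g. for predicates [p0, p1] with actions [a0, a1] and default d, the
    #   loop below builds, conceptually:
    #   cond_fn(p0, true_fn=a0, false_fn=cond_fn(p1, true_fn=a1, false_fn=d))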
3232    for predicate, action in reversed(list(zip(predicates, actions))):
3233      fn = functools.partial(
3234          cond_fn, predicate, true_fn=action, false_fn=fn, **cond_kwargs)
3235    if exclusive:
3236      with ops.control_dependencies([
3237          _assert_at_most_n_true(
3238              predicates, n=1, msg="Input error: exclusive=True")
3239      ]):
3240        return fn()
3241    else:
3242      return fn()
3243
3244
3245def _indexed_case_verify_and_canonicalize_args(branch_fns, default,
3246                                               branch_index):
3247  """Verifies input arguments for the case function.
3248
3249  Args:
3250    branch_fns: Dict or list of pairs of an `int` and a callable which
3251      returns a list of tensors.
3252    default: Optional callable that returns a list of tensors.
    branch_index: An int `Tensor` that selects the corresponding entry in
      `branch_fns`.
3255
3256  Raises:
3257    TypeError: If `branch_fns` is not a list/dictionary.
3258    TypeError: If `branch_fns` is a list but does not contain 2-tuples or
3259               callables.
3260    TypeError: If `fns[i]` is not callable for any i, or `default` is not
3261               callable.
3262
3263  Returns:
3264    branch_fns: validated list of callables for each branch (default last).
3265  """
3266  if not isinstance(branch_index, ops.Tensor):
3267    raise TypeError("'branch_index' must be a Tensor, got {}".format(
3268        type(branch_index)))
3269  if not branch_index.dtype.is_integer:
3270    raise TypeError("'branch_index' must be an integer Tensor, got {}".format(
3271        branch_index.dtype))
3272
3273  if not branch_fns:
3274    raise ValueError("Must provide at least one item in 'branch_fns'")
3275  if not isinstance(branch_fns, (list, _basetuple, dict)):
3276    raise TypeError("'branch_fns' must be a list, tuple, or dict")
3277
3278  if isinstance(branch_fns, dict):
3279    branch_fns = branch_fns.items()
3280
3281  if all(callable(fn) for fn in branch_fns):
3282    branch_fns = list(enumerate(branch_fns))
3283
3284  for key_fn_pair in branch_fns:
3285    if not isinstance(key_fn_pair, _basetuple) or len(key_fn_pair) != 2:
3286      raise TypeError("Each entry in 'branch_fns' must be a 2-tuple. "
3287                      f"Received {key_fn_pair}.")
3288    key, branch_fn = key_fn_pair
3289
3290    if not isinstance(key, int):
3291      raise TypeError("key must be a Python `int`, got {}".format(type(key)))
3292
3293    if not callable(branch_fn):
3294      raise TypeError("fn for key {} must be callable.".format(key))
3295
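  # Keys must form the contiguous range 0..len(keys)-1 with no duplicates,
  # e.g. [0, 1, 2] is valid while [0, 2] or [0, 1, 1] is not.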
3296  keys = [p[0] for p in branch_fns]
3297  if min(keys) < 0 or max(keys) >= len(keys) or len(set(keys)) != len(keys):
3298    raise ValueError(
3299        "branch indices (keys) must form contiguous range of [0 to {}) but "
3300        "found {{{}}}".format(len(keys), ",".join(map(str, sorted(keys)))))
3301  actions = [p[1] for p in sorted(branch_fns)]
3302  if default is not None:
3303    actions.append(default)
3304  return actions
3305
3306
3307def _indexed_case_helper(branch_fns,
3308                         default,
3309                         branch_index,
3310                         name,
3311                         lower_using_switch_merge=None):
3312  """Implementation of case that emits the n-way indexed Case op.
3313
3314  Args:
    branch_fns: Dict or list of pairs of an `int` and a callable which returns
      a list of tensors.
    default: Optional callable that returns a list of tensors.
    branch_index: An int `Tensor` that selects the corresponding entry in
      `branch_fns`.
3320    name: A name for this operation (optional).
3321    lower_using_switch_merge: Lower this op using switch merge ops (optional).
3322
3323  Returns:
3324    The tensors returned by the pair whose key matched branch_index, or
3325    those returned by `default` if none does.
3326
3327  Raises:
3328    TypeError: If `branch_fns` is not a list/dictionary.
3329    TypeError: If `branch_fns` is a list but does not contain 2-tuples or
3330               callables.
3331    TypeError: If `fns[i]` is not callable for any i, or `default` is not
3332               callable.
3333  """
3334  branch_fns = _indexed_case_verify_and_canonicalize_args(
3335      branch_fns, default, branch_index)
3336  with ops.name_scope(name, "case", [branch_index]):
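    # Eager fast path: with a concrete (non-symbolic) branch index, dispatch
    # directly in Python; an out-of-range index selects the last branch, which
    # is the default if one was provided.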
3337    if context.executing_eagerly() and not hasattr(branch_index, "graph"):
3338      branch_index = array_ops.where(
3339          math_ops.less(branch_index, 0)
3340          | math_ops.greater_equal(branch_index, len(branch_fns)),
3341          len(branch_fns) - 1, branch_index)
3342      return branch_fns[int(branch_index)]()
3343    return cond_v2.indexed_case(
3344        branch_index,
3345        branch_fns,
3346        lower_using_switch_merge=lower_using_switch_merge)
3347
3348
3349@tf_export("case", v1=[])
3350@dispatch.add_dispatch_support
3351def case_v2(pred_fn_pairs,
3352            default=None,
3353            exclusive=False,
3354            strict=False,
3355            name="case"):
3356  """Create a case operation.
3357
3358  See also `tf.switch_case`.
3359
  The `pred_fn_pairs` parameter is a list of N pairs.
3361  Each pair contains a boolean scalar tensor and a python callable that
3362  creates the tensors to be returned if the boolean evaluates to True.
3363  `default` is a callable generating a list of tensors. All the callables
3364  in `pred_fn_pairs` as well as `default` (if provided) should return the same
3365  number and types of tensors.
3366
3367  If `exclusive==True`, all predicates are evaluated, and an exception is
3368  thrown if more than one of the predicates evaluates to `True`.
3369  If `exclusive==False`, execution stops at the first predicate which
3370  evaluates to True, and the tensors generated by the corresponding function
3371  are returned immediately. If none of the predicates evaluate to True, this
3372  operation returns the tensors generated by `default`.
3373
3374  `tf.case` supports nested structures as implemented in
3375  `tf.nest`. All of the callables must return the same (possibly nested) value
3376  structure of lists, tuples, and/or named tuples. Singleton lists and tuples
3377  form the only exceptions to this: when returned by a callable, they are
3378  implicitly unpacked to single values. This behavior is disabled by passing
3379  `strict=True`.
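
  For example, a sketch of the singleton-unpacking behavior (`x` and `y` as in
  the examples below):

  ```python
  r = tf.case([(tf.less(x, y), lambda: [tf.constant(17)])],
              default=lambda: [tf.constant(23)])
  # With the default strict=False, `r` is a single Tensor; with strict=True it
  # would be a one-element list.
  ```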
3380
3381  @compatibility(v2)
3382  `pred_fn_pairs` could be a dictionary in v1. However, tf.Tensor and
3383  tf.Variable are no longer hashable in v2, so cannot be used as a key for a
3384  dictionary.  Please use a list or a tuple instead.
3385  @end_compatibility
3386
3387
3388  **Example 1:**
3389
3390  Pseudocode:
3391
3392  ```
3393  if (x < y) return 17;
3394  else return 23;
3395  ```
3396
3397  Expressions:
3398
3399  ```python
3400  f1 = lambda: tf.constant(17)
3401  f2 = lambda: tf.constant(23)
3402  r = tf.case([(tf.less(x, y), f1)], default=f2)
3403  ```
3404
3405  **Example 2:**
3406
3407  Pseudocode:
3408
3409  ```
3410  if (x < y && x > z) raise OpError("Only one predicate may evaluate to True");
3411  if (x < y) return 17;
3412  else if (x > z) return 23;
3413  else return -1;
3414  ```
3415
3416  Expressions:
3417
3418  ```python
3419  def f1(): return tf.constant(17)
3420  def f2(): return tf.constant(23)
3421  def f3(): return tf.constant(-1)
3422  r = tf.case([(tf.less(x, y), f1), (tf.greater(x, z), f2)],
3423           default=f3, exclusive=True)
3424  ```
3425
3426  Args:
3427    pred_fn_pairs: List of pairs of a boolean scalar tensor and a callable which
3428      returns a list of tensors.
3429    default: Optional callable that returns a list of tensors.
3430    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
3431    strict: A boolean that enables/disables 'strict' mode; see above.
3432    name: A name for this operation (optional).
3433
3434  Returns:
3435    The tensors returned by the first pair whose predicate evaluated to True, or
3436    those returned by `default` if none does.
3437
3438  Raises:
3439    TypeError: If `pred_fn_pairs` is not a list/tuple.
3440    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
3441    TypeError: If `fns[i]` is not callable for any i, or `default` is not
3442               callable.
3443  """
3444  return _case_helper(
3445      cond,
3446      pred_fn_pairs,
3447      default,
3448      exclusive,
3449      name,
3450      allow_python_preds=False,
3451      strict=strict)
3452
3453
3454@tf_export(v1=["case"])
3455@dispatch.add_dispatch_support
3456def case(pred_fn_pairs,
3457         default=None,
3458         exclusive=False,
3459         strict=False,
3460         name="case"):
3461  """Create a case operation.
3462
3463  See also `tf.switch_case`.
3464
  The `pred_fn_pairs` parameter is a dict or a list of N pairs.
3466  Each pair contains a boolean scalar tensor and a python callable that
3467  creates the tensors to be returned if the boolean evaluates to True.
3468  `default` is a callable generating a list of tensors. All the callables
3469  in `pred_fn_pairs` as well as `default` (if provided) should return the same
3470  number and types of tensors.
3471
3472  If `exclusive==True`, all predicates are evaluated, and an exception is
3473  thrown if more than one of the predicates evaluates to `True`.
3474  If `exclusive==False`, execution stops at the first predicate which
3475  evaluates to True, and the tensors generated by the corresponding function
3476  are returned immediately. If none of the predicates evaluate to True, this
3477  operation returns the tensors generated by `default`.
3478
3479  `tf.case` supports nested structures as implemented in
3480  `tf.nest`. All of the callables must return the same (possibly nested) value
3481  structure of lists, tuples, and/or named tuples. Singleton lists and tuples
3482  form the only exceptions to this: when returned by a callable, they are
3483  implicitly unpacked to single values. This behavior is disabled by passing
3484  `strict=True`.
3485
3486  If an unordered dictionary is used for `pred_fn_pairs`, the order of the
3487  conditional tests is not guaranteed. However, the order is guaranteed to be
3488  deterministic, so that variables created in conditional branches are created
3489  in fixed order across runs.
3490
3491  @compatibility(eager)
3492  Unordered dictionaries are not supported in eager mode when `exclusive=False`.
3493  Use a list of tuples instead.
3494  @end_compatibility
3495
3496
3497  **Example 1:**
3498
3499  Pseudocode:
3500
3501  ```
3502  if (x < y) return 17;
3503  else return 23;
3504  ```
3505
3506  Expressions:
3507
3508  ```python
3509  f1 = lambda: tf.constant(17)
3510  f2 = lambda: tf.constant(23)
3511  r = tf.case([(tf.less(x, y), f1)], default=f2)
3512  ```
3513
3514  **Example 2:**
3515
3516  Pseudocode:
3517
3518  ```
3519  if (x < y && x > z) raise OpError("Only one predicate may evaluate to True");
3520  if (x < y) return 17;
3521  else if (x > z) return 23;
3522  else return -1;
3523  ```
3524
3525  Expressions:
3526
3527  ```python
3528  def f1(): return tf.constant(17)
3529  def f2(): return tf.constant(23)
3530  def f3(): return tf.constant(-1)
3531  r = tf.case({tf.less(x, y): f1, tf.greater(x, z): f2},
3532           default=f3, exclusive=True)
3533  ```
3534
3535  Args:
3536    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a
3537      callable which returns a list of tensors.
3538    default: Optional callable that returns a list of tensors.
3539    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
3540    strict: A boolean that enables/disables 'strict' mode; see above.
3541    name: A name for this operation (optional).
3542
3543  Returns:
3544    The tensors returned by the first pair whose predicate evaluated to True, or
3545    those returned by `default` if none does.
3546
3547  Raises:
3548    TypeError: If `pred_fn_pairs` is not a list/dictionary.
3549    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
3550    TypeError: If `fns[i]` is not callable for any i, or `default` is not
3551               callable.
3552  """
3553  return _case_helper(
3554      cond,
3555      pred_fn_pairs,
3556      default,
3557      exclusive,
3558      name,
3559      allow_python_preds=False,
3560      strict=strict)
3561
3562
3563@tf_export("switch_case")
3564def switch_case(branch_index,
3565                branch_fns,
3566                default=None,
3567                name="switch_case"):
3568  """Create a switch/case operation, i.e. an integer-indexed conditional.
3569
3570  See also `tf.case`.
3571
3572  This op can be substantially more efficient than `tf.case` when exactly one
3573  branch will be selected. `tf.switch_case` is more like a C++ switch/case
3574  statement than `tf.case`, which is more like an if/elif/elif/else chain.
3575
3576  The `branch_fns` parameter is either a dict from `int` to callables, or list
3577  of (`int`, callable) pairs, or simply a list of callables (in which case the
3578  index is implicitly the key). The `branch_index` `Tensor` is used to select an
3579  element in `branch_fns` with matching `int` key, falling back to `default`
3580  if none match, or `max(keys)` if no `default` is provided. The keys must form
3581  a contiguous set from `0` to `len(branch_fns) - 1`.
3582
3583  `tf.switch_case` supports nested structures as implemented in `tf.nest`. All
3584  callables must return the same (possibly nested) value structure of lists,
3585  tuples, and/or named tuples.
3586
3587  **Example:**
3588
3589  Pseudocode:
3590
3591  ```c++
3592  switch (branch_index) {  // c-style switch
3593    case 0: return 17;
3594    case 1: return 31;
3595    default: return -1;
3596  }
3597  ```
3598  or
3599  ```python
3600  branches = {0: lambda: 17, 1: lambda: 31}
3601  branches.get(branch_index, lambda: -1)()
3602  ```
3603
3604  Expressions:
3605
3606  ```python
3607  def f1(): return tf.constant(17)
3608  def f2(): return tf.constant(31)
3609  def f3(): return tf.constant(-1)
3610  r = tf.switch_case(branch_index, branch_fns={0: f1, 1: f2}, default=f3)
3611  # Equivalent: tf.switch_case(branch_index, branch_fns={0: f1, 1: f2, 2: f3})
3612  ```
3613
3614  Args:
3615    branch_index: An int Tensor specifying which of `branch_fns` should be
3616      executed.
3617    branch_fns: A `dict` mapping `int`s to callables, or a `list` of
3618      (`int`, callable) pairs, or simply a list of callables (in which case the
3619      index serves as the key). Each callable must return a matching structure
3620      of tensors.
3621    default: Optional callable that returns a structure of tensors.
3622    name: A name for this operation (optional).
3623
3624  Returns:
3625    The tensors returned by the callable identified by `branch_index`, or those
3626    returned by `default` if no key matches and `default` was provided, or those
3627    returned by the max-keyed `branch_fn` if no `default` is provided.
3628
3629  Raises:
3630    TypeError: If `branch_fns` is not a list/dictionary.
3631    TypeError: If `branch_fns` is a list but does not contain 2-tuples or
3632               callables.
3633    TypeError: If `fns[i]` is not callable for any i, or `default` is not
3634               callable.
3635  """
3636  return _indexed_case_helper(branch_fns, default, branch_index, name)
3637
3638
3639@tf_export("__internal__.execute_fn_for_device", v1=[])
3640def execute_fn_for_device(device_branch_fns, default_fn, name="execute_fn"):
3641  """Executes one of the provided callables based on the device placement.
3642
  This API is used when the implementation of a high-level function depends on
  the underlying device placement. It takes a dictionary mapping device types
  to callables. The device types include "CPU", "GPU", "TPU", etc. When the
  type of the device this op runs on matches a key in 'device_branch_fns',
  the corresponding callable is executed; 'default_fn' is executed if no key
  matches.
3649
3650  **Example:**
3651  ```python
3652  def f1(): return tf.constant(1)
3653  def f2(): return tf.constant(2)
3654  r = tf.execute_fn_for_device({"CPU": f1, "GPU": f2}, default_fn=f1)
3655  ```
  'r' evaluates to 1 when running on a CPU, to 2 when running on a GPU, and
  to 1 when running on any other device type.
3658
3659
3660  Args:
3661    device_branch_fns: a dictionary of device types to the callables. Each
3662      callable must return a matching structure of tensors.
3663    default_fn: fallback callable when the underlying device does not match any
3664      key in the 'device_branch_fns'.
3665    name: A name for this operation (optional).
3666
3667  Returns:
3668    The tensors returned by the callable identified by device type during
3669    execution, or those returned by 'default_fn' if no key matches.
3670  """
  # Always execute the default fn under XLA to avoid the more complicated
  # graph produced by the Case op. See more discussion in b/167276293.
3673  is_in_xla = util.GraphOrParentsInXlaContext(ops.get_default_graph())
3674  if is_in_xla:
3675    return default_fn()
3676  device_branch_fns_upper = {k.upper(): v for k, v in device_branch_fns.items()}
3677  branch_fns = list(device_branch_fns_upper.values())
3678  devices = list(device_branch_fns_upper.keys())
3679  device_index = gen_functional_ops.device_index(device_names=devices)
3680  return _indexed_case_helper(
3681      branch_fns,
3682      default_fn,
3683      device_index,
3684      name,
3685      lower_using_switch_merge=False)
3686
3687
3688class XLAControlFlowContext(ControlFlowContext):
3689  """Base class for XLA and TPU control flow contexts."""
3690
3691  def __init__(self):
3692    super(XLAControlFlowContext, self).__init__()
3693    self._name = "XLAControlFlowContext"
3694
3695  def to_control_flow_context_def(self, context_def, export_scope=None):
3696    # pylint: disable=useless-super-delegation
3697    # NOTE(slebedev): the method is required by `ControlFlowContext`.
3698    super(XLAControlFlowContext,
3699          self).to_control_flow_context_def(context_def, export_scope)
3700
3701  def IsXLAContext(self):
3702    return True
3703
3704  def AddOp(self, _):
3705    pass
3706
3707  def AddValue(self, x):
3708    return x
3709
3710  def RequiresUniqueFunctionRetracing(self):
3711    """Returns whether the tf.function should be retraced if the context changes.
3712    """
3713    return False
3714
3715
3716@tf_export("__internal__.get_enclosing_xla_context", v1=[])
3717def get_enclosing_xla_context():
3718  """Recursively find and return the XLAControlFlowContext."""
3719  graph = ops.get_default_graph()
3720  while graph is not None:
3721    # pylint: disable=protected-access
3722    context_ = graph._get_control_flow_context()
3723    # pylint: enable=protected-access
3724    while context_ is not None:
3725      if isinstance(context_, XLAControlFlowContext):
3726        return context_
3727      context_ = context_.outer_context
3728    # This may be a FuncGraph due to defuns or v2 control flow. We need to
3729    # find the original graph with the XLAControlFlowContext.
3730    graph = getattr(graph, "outer_graph", None)
3731  return None
3732
3733
3734def from_control_flow_context_def(context_def, import_scope=None):
3735  """Deserializes `context_def` into the appropriate ControlFlowContext.
3736
3737  Args:
3738    context_def: ControlFlowContextDef proto
3739    import_scope: Optional `string`. Name scope to add.
3740
3741  Returns:
3742    A ControlFlowContext subclass
3743  """
3744  if context_def.HasField("cond_ctxt"):
3745    return CondContext.from_proto(
3746        context_def.cond_ctxt, import_scope=import_scope)
3747  if context_def.HasField("while_ctxt"):
3748    return WhileContext.from_proto(
3749        context_def.while_ctxt, import_scope=import_scope)
3750  raise NotImplementedError("Unknown ControlFlowContextDef field: %s" %
3751                            context_def.WhichOneof("ctxt"))
3752
3753
3754ops.register_proto_function(
3755    ops.GraphKeys.COND_CONTEXT,
3756    proto_type=control_flow_pb2.CondContextDef,
3757    to_proto=CondContext.to_proto,
3758    from_proto=CondContext.from_proto)
3759
3760ops.register_proto_function(
3761    ops.GraphKeys.WHILE_CONTEXT,
3762    proto_type=control_flow_pb2.WhileContextDef,
3763    to_proto=WhileContext.to_proto,
3764    from_proto=WhileContext.from_proto)
3765