1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Control Flow Operations.
16
17See the [autograph](https://www.tensorflow.org/guide/autograph) guide.
18"""
19# pylint: disable=g-bad-name
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import abc
25import collections
26import functools
27
28import six
29
30from tensorflow.core.framework import attr_value_pb2
31from tensorflow.core.protobuf import control_flow_pb2
32from tensorflow.python.eager import context
33from tensorflow.python.framework import composite_tensor
34from tensorflow.python.framework import constant_op
35from tensorflow.python.framework import dtypes
36from tensorflow.python.framework import errors
37from tensorflow.python.framework import ops
38from tensorflow.python.framework import tensor_shape
39from tensorflow.python.framework import tensor_spec
40from tensorflow.python.framework import tensor_util
41from tensorflow.python.framework import type_spec
42from tensorflow.python.ops import array_ops
43from tensorflow.python.ops import control_flow_util as util
44from tensorflow.python.ops import gen_array_ops
45from tensorflow.python.ops import gen_control_flow_ops
46from tensorflow.python.ops import gen_functional_ops
47from tensorflow.python.ops import gen_logging_ops
48from tensorflow.python.ops import gen_math_ops
49from tensorflow.python.ops import math_ops
50from tensorflow.python.ops import tensor_array_ops
51# go/tf-wildcard-import
52# pylint: disable=wildcard-import,undefined-variable
53from tensorflow.python.ops.gen_control_flow_ops import *
54# pylint: enable=wildcard-import
55from tensorflow.python.platform import tf_logging as logging
56from tensorflow.python.util import compat
57from tensorflow.python.util import deprecation
58from tensorflow.python.util import dispatch
59from tensorflow.python.util import nest
60from tensorflow.python.util import tf_should_use
61from tensorflow.python.util.lazy_loader import LazyLoader
62from tensorflow.python.util.tf_export import tf_export
63
64# This is to avoid a circular dependency:
65# cond_v2 -> gradients_util -> control_flow_ops
66cond_v2 = LazyLoader("cond_v2", globals(),
67                     "tensorflow.python.ops.cond_v2")
68
69# This is to avoid circular dependencies:
70# while_v2 -> control_flow_ops
71# while_v2 -> gradients_util -> control_flow_ops
72while_v2 = LazyLoader("while_v2", globals(),
73                      "tensorflow.python.ops.while_v2")
74
75# def_function also uses cond
76def_function = LazyLoader(
77    "def_function", globals(),
78    "tensorflow.python.eager.def_function")
79
80
81# We override 'tuple' with a control flow op of the same name later in this
82# module, so we keep python's existing 'tuple' here for later use.
83_basetuple = tuple
84
85
86def _summarize_eager(tensor, summarize=None):
87  """Returns a summarized string representation of eager `tensor`.
88
89  Args:
90    tensor: EagerTensor to summarize.
91    summarize: Include only this many leading elements of `tensor`.
92  """
93  # Emulate the behavior of Tensor::SummarizeValue()
94  if summarize is None:
95    summarize = 3
96  elif summarize < 0:
97    summarize = array_ops.size(tensor)
98
99  # reshape((-1,)) is the fastest way to get a flat array view
100  if tensor._rank():  # pylint: disable=protected-access
101    flat = tensor.numpy().reshape((-1,))
102    lst = [str(x) for x in flat[:summarize]]
103    if len(lst) < flat.size:
104      lst.append("...")
105  else:
106    # tensor.numpy() returns a scalar for zero dimensional arrays
107    if gen_math_ops.not_equal(summarize, 0):
108      lst = [str(tensor.numpy())]
109    else:
110      lst = []
111
112  return ", ".join(lst)
113
114
115# pylint: disable=protected-access
116
117
118# Assert and Print are special symbols in python, so we must
119# use an upper-case version of them.
120@tf_export("debugging.Assert", "Assert")
121@dispatch.add_dispatch_support
122@tf_should_use.should_use_result
123def Assert(condition, data, summarize=None, name=None):
124  """Asserts that the given condition is true.
125
126  If `condition` evaluates to false, print the list of tensors in `data`.
127  `summarize` determines how many entries of the tensors to print.
128
129  Args:
130    condition: The condition to evaluate.
131    data: The tensors to print out when condition is false.
132    summarize: Print this many entries of each tensor.
133    name: A name for this operation (optional).
134
135  Returns:
136    assert_op: An `Operation` that, when executed, raises a
137    `tf.errors.InvalidArgumentError` if `condition` is not true.
138    @compatibility(eager)
139    returns None
140    @end_compatibility
141
142  Raises:
143    @compatibility(TF1)
144    When in TF V1 mode (that is, outside `tf.function`) Assert needs a control
145    dependency on the output to ensure the assertion executes:
146
147  ```python
148  # Ensure maximum element of x is smaller or equal to 1
149  assert_op = tf.Assert(tf.less_equal(tf.reduce_max(x), 1.), [x])
150  with tf.control_dependencies([assert_op]):
151    ... code using x ...
152  ```
153
154    @end_compatibility
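
  A minimal eager-mode sketch of the failure path (illustrative only; assumes
  `import tensorflow as tf` and TF2 eager execution):

  ```python
  x = tf.constant([2.0, 3.0])
  try:
    tf.debugging.Assert(tf.less_equal(tf.reduce_max(x), 1.), [x])
  except tf.errors.InvalidArgumentError as e:
    print(e.message)  # Summarizes the tensors passed in `data`.
  ```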
155  """
156  if context.executing_eagerly():
157    if not condition:
158      xs = ops.convert_n_to_tensor(data)
159      data_str = [_summarize_eager(x, summarize) for x in xs]
160      raise errors.InvalidArgumentError(
161          node_def=None,
162          op=None,
163          message="Expected '%s' to be true. Summarized data: %s" %
164          (condition, "\n".join(data_str)))
165    return
166
167  with ops.name_scope(name, "Assert", [condition, data]) as name:
168    xs = ops.convert_n_to_tensor(data)
169    if all(x.dtype in {dtypes.string, dtypes.int32} for x in xs):
170      # As a simple heuristic, we assume that string and int32 are
171      # on host to avoid the need to use cond. If that is not the case,
172      # we will pay the price copying the tensor to host memory.
173      return gen_logging_ops._assert(condition, data, summarize, name="Assert")
174    else:
175      condition = ops.convert_to_tensor(condition, name="Condition")
176
177      def true_assert():
178        return gen_logging_ops._assert(
179            condition, data, summarize, name="Assert")
180
181      guarded_assert = cond(condition, no_op, true_assert, name="AssertGuard")
182      if context.executing_eagerly():
183        return
184      return guarded_assert.op
185
186
187def _Identity(tensor, name=None):
188  """Return a tensor with the same shape and contents as the input tensor.
189
190  Args:
191    tensor: A Tensor.
192    name: A name for this operation (optional).
193
194  Returns:
195    A Tensor with the same type and value as the input Tensor.
196  """
197  tensor = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True)
198  if isinstance(tensor, ops.Tensor):
199    if tensor.dtype._is_ref_dtype:  # pylint: disable=protected-access
200      return gen_array_ops.ref_identity(tensor, name=name)
201    else:
202      return array_ops.identity(tensor, name=name)
203  elif isinstance(tensor, composite_tensor.CompositeTensor):
204    return nest.map_structure(_Identity, tensor, expand_composites=True)
205  else:
206    raise TypeError("'tensor' must be a Tensor or CompositeTensor. "
207                    f"Received: {type(tensor)}.")
208
209
210def _NextIteration(tensor, name=None):
211  tensor = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True)
212  if isinstance(tensor, ops.Tensor):
213    if tensor.dtype._is_ref_dtype:  # pylint: disable=protected-access
214      return ref_next_iteration(tensor, name=name)
215    else:
216      return next_iteration(tensor, name=name)
217  elif isinstance(tensor, composite_tensor.CompositeTensor):
218    return nest.map_structure(_NextIteration, tensor, expand_composites=True)
219  else:
220    raise TypeError("'tensor' must be a Tensor or CompositeTensor. "
221                    f"Received: {type(tensor)}.")
222
223
224def _Enter(tensor,
225           frame_name,
226           is_constant=False,
227           parallel_iterations=10,
228           use_ref=True,
229           use_input_shape=True,
230           name=None):
231  """Creates or finds a child frame, and makes `tensor` available to it.
232
233  The unique `frame_name` is used by the `Executor` to identify frames. If
234  `is_constant` is true, `tensor` is a constant in the child frame; otherwise
235  it may be changed in the child frame. At most `parallel_iterations`
236  iterations are run in parallel in the child frame.
237
238  Args:
239    tensor: The tensor to be made available to the child frame.
240    frame_name: The name of the child frame.
241    is_constant: If true, the output is constant within the child frame.
242    parallel_iterations: The number of iterations allowed to run in parallel.
243    use_ref: If true, use ref_enter if tensor is of ref type.
244    use_input_shape: If true, set the result's shape based on tensor's shape.
245    name: A name for this operation (optional).
246
247  Returns:
248    The same tensor as `tensor`.
249  """
250  tensor = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True)
251  if isinstance(tensor, ops.Tensor):
252    if tensor.dtype._is_ref_dtype and use_ref:  # pylint: disable=protected-access
253      result = gen_control_flow_ops.ref_enter(
254          tensor, frame_name, is_constant, parallel_iterations, name=name)
255    else:
256      result = gen_control_flow_ops.enter(
257          tensor, frame_name, is_constant, parallel_iterations, name=name)
258    if use_input_shape:
259      result.set_shape(tensor.get_shape())
260    return result
261  elif isinstance(tensor, composite_tensor.CompositeTensor):
262
263    def enter_component(t):
264      return _Enter(t, frame_name, is_constant, parallel_iterations, use_ref,
265                    use_input_shape)
266
267    return nest.map_structure(enter_component, tensor, expand_composites=True)
268  else:
269    raise TypeError("'tensor' must be a Tensor or CompositeTensor. "
270                    f"Received: {type(tensor)}.")
271
272
273def exit(tensor, name=None):  # pylint: disable=redefined-builtin
274  """Exits the current frame to its parent frame.
275
276  Exit makes its input `tensor` available to the parent frame.
277
278  Args:
279    tensor: The tensor to be made available to the parent frame.
280    name: A name for this operation (optional).
281
282  Returns:
283    The same tensor as `tensor`.
284  """
285  tensor = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True)
286  if isinstance(tensor, ops.Tensor):
287    if tensor.dtype._is_ref_dtype:  # pylint: disable=protected-access
288      return gen_control_flow_ops.ref_exit(tensor, name)
289    else:
290      return gen_control_flow_ops._exit(tensor, name)
291  elif isinstance(tensor, composite_tensor.CompositeTensor):
292    return nest.map_structure(exit, tensor, expand_composites=True)
293  else:
294    raise TypeError("'tensor' must be a Tensor or CompositeTensor. "
295                    f"Received: {type(tensor)}.")
296
297
298def switch(data, pred, dtype=None, name=None):
299  """Forwards `data` to an output determined by `pred`.
300
301  If `pred` is false, the `data` input is forwarded to the first output.
302  Otherwise, the data goes to the second output.
303
304  This op handles `Tensor`s and `IndexedSlices`.
305
306  Args:
307    data: The tensor to be forwarded to the appropriate output.
308    pred: A scalar that specifies which output port will receive data.
309    dtype: Optional element type for the returned tensor. If missing, the type
310      is inferred from the type of `data`.
311    name: A name for this operation (optional).
312
313  Returns:
314    `(output_false, output_true)`: If `pred` is true, data will be forwarded
315    to `output_true`, otherwise it goes to `output_false`.
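
  For illustration, a minimal graph-mode sketch using this module's `switch`
  (assumes `import tensorflow as tf` and graph construction; names are
  illustrative):

  ```python
  with tf.Graph().as_default():
    data = tf.constant([1, 2, 3])
    pred = tf.constant(True)
    output_false, output_true = switch(data, pred)
    # Only the output selected by `pred` carries `data`; the other output is
    # dead, and ops downstream of it are skipped by the executor.
  ```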
316  """
317  with ops.name_scope(name, "Switch", [data, pred]) as name:
318    data = ops.internal_convert_to_tensor_or_composite(
319        data, dtype=dtype, name="data", as_ref=True)
320    pred = ops.convert_to_tensor(pred, name="pred")
321    if isinstance(data, ops.Tensor):
322      return gen_control_flow_ops.switch(data, pred, name=name)
323    else:
324      if not isinstance(data, composite_tensor.CompositeTensor):
325        raise TypeError(
326            "'data' must be a Tensor or CompositeTensor. "
327            f"Received: {type(data)}.")
328      tensors = nest.flatten(data, expand_composites=True)
329      mapped = [gen_control_flow_ops.switch(tensor, pred) for tensor in tensors]
330      mapped_f, mapped_t = zip(*mapped)
331      return (nest.pack_sequence_as(data, mapped_f, expand_composites=True),
332              nest.pack_sequence_as(data, mapped_t, expand_composites=True))
333
334
335def _SwitchRefOrTensor(data, pred, name="Switch"):
336  """Forwards `data` to an output determined by `pred`.
337
338  If `pred` is false, the `data` input is forwarded to the first output.
339  Otherwise, the data goes to the second output.
340
341  This op handles `Tensor`s and `IndexedSlices`.
342
343  Args:
344    data: The tensor to be forwarded to the appropriate output.
345    pred: A scalar that specifies which output port will receive data.
346    name: A name for this operation (optional).
347
348  Returns:
349    `(output_false, output_true)`: If `pred` is true, data will be forwarded to
350    `output_true`, otherwise it goes to `output_false`.
351
352  Raises:
353    TypeError: if data is not a Tensor or IndexedSlices
354  """
355  data = ops.convert_to_tensor_or_composite(data, name="data")
356  # NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below
357  # addresses the following scenario.
358  #
359  # Assume you execute Optimizer.apply_gradients() in a branch of a cond().
360  #
361  # 1. The update op is created inside a `with ops.colocate_with(var):` block
362  #
363  # 2. Some tensor `data` is captured and a switch is created in a
364  #    `with ops.colocate_with(data):` block.
365  #
366  # with ops.colocate_with(var):
367  #  with ops.colocate_with(data):
368  #    op = ...
369  #
370  # var and data may be pinned to different devices, so we want ops
371  # created within ops.colocate_with(data) to ignore the existing stack.
372  with ops.colocate_with(data, ignore_existing=True):
373    if isinstance(data, ops.Tensor):
374      if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
375        return ref_switch(data, pred, name=name)
376    return switch(data, pred, name=name)
377
378
379def merge(inputs, name=None):
380  """Returns the value of an available element of `inputs`.
381
382  This op tests each of the tensors in `inputs` in turn to determine if any of
383  them is available. If it finds an available tensor, it returns it and its
384  index in `inputs`.
385
386  It is an error if more than one tensor in `inputs` is available. If no tensor
387  in `inputs` is available, the returned tensor and index are not set.
388
389  This op handles both `Tensor`s and `IndexedSlices`. If inputs has a mix of
390  `Tensor`s and `IndexedSlices`, all inputs are converted to IndexedSlices
391  before merging.
392
393  Args:
394    inputs: The input tensors, at most one of which is available.
395    name: A name for this operation (optional).
396
397  Returns:
398    A tuple containing the chosen input tensor and its index in `inputs`.
399
400  Raises:
401    ValueError: If any of the inputs is None, or inputs are IndexedSlices and
402      some but not all have a dense_shape property.
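
  As an illustrative sketch, `merge` is typically paired with `switch` to build
  a conditional dataflow by hand (graph mode; names are illustrative):

  ```python
  with tf.Graph().as_default():
    pred = tf.constant(True)
    x = tf.constant(1.0)
    x_false, x_true = switch(x, pred)
    # Each arithmetic op below sees at most one live input; `merge` forwards
    # whichever branch actually produced a value, plus that branch's index.
    output, index = merge([x_false * 2.0, x_true + 10.0])
  ```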
403  """
404  if any(inp is None for inp in inputs):
405    raise ValueError("At least one of the merge inputs is None: %s" % inputs)
406  with ops.name_scope(name, "Merge", inputs) as name:
407    inputs = [
408        ops.internal_convert_to_tensor_or_composite(inp, as_ref=True)
409        for inp in inputs
410    ]
411    if all(isinstance(v, ops.Tensor) for v in inputs):
412      if all(v.dtype._is_ref_dtype for v in inputs):  # pylint: disable=protected-access
413        return gen_control_flow_ops.ref_merge(inputs, name)
414      else:
415        return gen_control_flow_ops.merge(inputs, name)
416    else:
417      # If there is a mix of tensors and indexed slices, then convert the
418      # tensors to indexed slices.
419      if all(isinstance(v, (ops.IndexedSlices, ops.Tensor)) for v in inputs):
420        inputs = math_ops._as_indexed_slices_list(inputs, optimize=False)
421
422      for v in inputs:
423        if not isinstance(v, composite_tensor.CompositeTensor):
424          raise TypeError("Type %s not supported" % type(v))
425
426      for v in inputs[1:]:
427        nest.assert_same_structure(inputs[0], v, expand_composites=True)
428
429      flat_inputs = [nest.flatten(v, expand_composites=True) for v in inputs]
430      merged_results = [
431          gen_control_flow_ops.merge(component)
432          for component in zip(*flat_inputs)
433      ]
434      flat_merged = [tensor for (tensor, _) in merged_results]
435      chosen_index = merged_results[0][1]
436      merged_inputs = nest.pack_sequence_as(
437          inputs[0], flat_merged, expand_composites=True)
438      return (merged_inputs, chosen_index)
439
440
441# pylint: enable=protected-access
442
443
444def _convert_tensorarray_to_flow(tensor_or_tensor_array):
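  """Returns the flow `Tensor` of a `TensorArray`, or the input unchanged."""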
445  if isinstance(tensor_or_tensor_array, tensor_array_ops.TensorArray):
446    return tensor_or_tensor_array.flow
447  else:
448    return tensor_or_tensor_array
449
450
451def _convert_flows_to_tensorarrays(tensors_or_tensorarrays, tensors_or_flows):
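  """Rebuilds `TensorArray`s from flow tensors, keeping plain tensors as-is.

  Each element of `tensors_or_flows` is wrapped back into a `TensorArray` when
  the corresponding element of `tensors_or_tensorarrays` was a `TensorArray`.
  """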
452  if len(tensors_or_tensorarrays) != len(tensors_or_flows):
453    raise ValueError(
454        "Lengths of original Tensor list and new list do not match: %d vs. %d" %
455        (len(tensors_or_tensorarrays), len(tensors_or_flows)))
456  return [
457      tensor_array_ops.build_ta_with_new_flow(ta, t_or_flow) if isinstance(
458          ta, tensor_array_ops.TensorArray) else t_or_flow
459      for (ta, t_or_flow) in zip(tensors_or_tensorarrays, tensors_or_flows)
460  ]
461
462
463def _ShapeLessThanOrEqual(shape1, shape2):
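  """Returns True if `shape1` conforms to the (possibly partial) `shape2`.

  `shape2` acts as the less specific bound: the check passes if `shape2` is
  fully unknown, or if the ranks match and every dimension that is defined in
  `shape2` equals the corresponding dimension of `shape1`.
  """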
464  if shape2.dims is None:
465    return True
466  if shape1.ndims != shape2.ndims:
467    return False
468  for dim1, dim2 in zip(shape1.dims, shape2.dims):
469    if dim2.value is not None and dim1.value != dim2.value:
470      return False
471  return True
472
473
474def _get_shape_invariant(var, shape=None):
475  """Returns shape invariant(s) for the given variable.
476
477  Args:
478    var: The tensor whose shape is described.
479    shape: The shape invariant for the tensor.  If not specified, then a default
480      shape invariant for `var` is returned.
481
482  Returns:
483    `TensorShape` or `list` of `TensorShape`: The shape invariant for `var` (if
484    it is a `Tensor`), or the shape invariants for the components that comprise
485    `var` (if it is a `CompositeTensor`).
486  """
487  if isinstance(var, composite_tensor.CompositeTensor):
488    # Get a TypeSpec for `var`.
489    if shape is None:
490      spec = var._type_spec  # pylint: disable=protected-access
491    else:
492      spec = _shape_invariant_to_type_spec(var, shape)
493
494    tensor_specs = nest.flatten(spec, expand_composites=True)
495    return [tspec.shape for tspec in tensor_specs]
496
497  elif shape is None:
498    return var.shape
499  elif isinstance(shape, tensor_spec.TensorSpec):
500    if var.dtype != shape.dtype:
501      raise TypeError("TensorSpec %r is not compatible with %r" % (shape, var))
502    return shape.shape
503  elif isinstance(shape, type_spec.TypeSpec):
504    raise TypeError("TypeSpec %r is not compatible with %r" % (shape, var))
505  else:
506    return shape
507
508
509def _shape_invariant_to_type_spec(var, shape):
510  """Converts a shape invariant to a TypeSpec.
511
512  Args:
513    var: The tensor whose shape is described by the shape invariant.
514    shape: A `TypeSpec`, `TensorShape`, or `None`.  If `shape` is already a
515      `TypeSpec`, then it is simply returned as-is.
516
517  Returns:
518    A `TypeSpec` for `var`, consistent with the given shape.
519  """
520  if shape is None:
521    return type_spec.type_spec_from_value(var)
522  elif isinstance(shape, type_spec.TypeSpec):
523    if not shape.is_compatible_with(var):
524      raise TypeError("TypeSpec %r is not compatible with %r" % (shape, var))
525    return shape
526  elif not isinstance(shape, tensor_shape.TensorShape):
527    raise TypeError(
528        "'shape' must be one of TypeSpec, TensorShape or None. "
529        f"Received: {type(shape)}")
530
531  if isinstance(var, ops.Tensor):
532    return tensor_spec.TensorSpec(shape, var.dtype)
533
534  elif isinstance(var, composite_tensor.CompositeTensor):
535    try:
536      return var._shape_invariant_to_type_spec(shape)  # pylint: disable=protected-access
537    except NotImplementedError:
538      raise TypeError(
539          "To describe or constrain a %s, use a %s instead of a TensorShape." %
540          (type(var).__name__, type(var._type_spec).__name__))  # pylint: disable=protected-access
541
542  else:
543    raise TypeError("Expected var to be a Tensor or CompositeTensor, got %s"
544                    % var)
545
546
547def _SetShapeInvariants(input_vars, enter_vars, shapes):
548  """Set the shapes of the tensors in `enter_vars` to `shapes`.
549
550  Args:
551    input_vars: A list of tensors that are inputs to `enter_vars`.
552    enter_vars: A list of tensors whose shapes will be set.
553    shapes: A (possibly nested) list of shapes.
554
555  Raises:
556    ValueError: If the initial shape of any tensor in `input_vars` is not
557      compatible with its corresponding shape invariant in `shapes`.
558  """
559  if shapes is None:
560    return
561  flat_shapes = nest.flatten(shapes)
562  if not all(isinstance(s, tensor_shape.TensorShape) for s in flat_shapes):
563    raise ValueError("'shapes' must be a (possibly nested) list of "
564                     "TensorShapes.")
565  # Check that the shapes of the inputs are less than the shape invariants,
566  # and set the shapes of `enter_vars` to the shape invariants.
567  for inp, var, shape in zip(input_vars, enter_vars, flat_shapes):
568    if isinstance(var, ops.Tensor):
569      if not _ShapeLessThanOrEqual(inp.get_shape(), shape):
570        raise ValueError(
571            "The shape invariant specified for %s is not compatible with "
572            "the initial shape of the loop variable. It enters the loop "
573            "with shape %s, but the specified shape invariant is %s." %
574            (inp.name, inp.get_shape(), shape))
575      var.set_shape(shape)
576    else:
577      raise TypeError("'enter_vars' must be a list of Tensors. "
578                      f"Received: {type(var)}.")
579
580
581def _EnforceShapeInvariant(merge_var, next_var):
582  """Check if the shapes of the loops variables are invariants.
583
584  Args:
585    merge_var: The tensor representing the initial values of the loop
586      variables.
587    next_var: The tensor representing the values of the loop variables
588      after one loop iteration.
589
590  Raises:
591    ValueError: If any tensor in `merge_var` has a more specific shape than
592      its corresponding tensor in `next_var`.
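
  For reference, the user-facing remedy suggested by the error below is the
  `shape_invariants` argument of `tf.while_loop`; a minimal sketch (assumes
  `import tensorflow as tf`):

  ```python
  i0 = tf.constant(0)
  m0 = tf.ones([2, 2])
  c = lambda i, m: i < 10
  b = lambda i, m: [i + 1, tf.concat([m, m], axis=0)]
  tf.while_loop(
      c, b, loop_vars=[i0, m0],
      shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
  # Without the [None, 2] invariant, the growing first dimension of `m`
  # would trigger the ValueError raised here.
  ```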
593  """
594  if isinstance(merge_var, ops.Tensor):
595    m_shape = merge_var.get_shape()
596    n_shape = next_var.get_shape()
597    if not _ShapeLessThanOrEqual(n_shape, m_shape):
598      enter = merge_var.op.inputs[0].op
599      assert util.IsLoopEnter(enter)
600      input_t = enter.inputs[0]
601      raise ValueError(
602          "Input tensor '%s' enters the loop with shape %s, but has shape %s "
603          "after one iteration. To allow the shape to vary across iterations, "
604          "use the `shape_invariants` argument of tf.while_loop to specify a "
605          "less-specific shape." % (input_t.name, input_t.shape, n_shape))
606  else:
607    raise TypeError("'merge_var' must be a Tensor. "
608                    f"Received: {type(merge_var)}.")
609
610
611def _AddNextAndBackEdge(m, v, enforce_shape_invariant=True):
612  """Add NextIteration and back edge from v to m."""
613  if isinstance(m, ops.Tensor):
614    v = ops.convert_to_tensor(v)
615    v = _NextIteration(v)
616    if enforce_shape_invariant:
617      # Make sure the shapes of loop outputs are correct. We do this before
618      # calling _update_input, which will raise a less-helpful error message if
619      # the types don't match.
620      # TODO(skyewm): call this for other cases below (needs testing)
621      _EnforceShapeInvariant(m, v)
622    m.op._update_input(1, v)  # pylint: disable=protected-access
623  elif isinstance(m, composite_tensor.CompositeTensor):
624    # pylint: disable=protected-access
625    def update_component(m_component, v_component):
626      m_component.op._update_input(1, v_component)
627
628    if isinstance(m, ops.IndexedSlices):
629      v = math_ops._as_indexed_slices(v, optimize=False)
630    # pylint: enable=protected-access
631    v = _NextIteration(v)
632    return nest.map_structure(update_component, m, v, expand_composites=True)
633  else:
634    raise TypeError("'m' must be a Tensor or CompositeTensor. "
635                    f"Received: {type(m)}.")
636  return v
637
638
639@six.add_metaclass(abc.ABCMeta)
640class ControlFlowContext(object):
641  """The base class for control flow context.
642
643  The usage pattern is a sequence of (Enter, Exit) followed by a final
644  ExitResult.
645
646  We maintain the following state for control flow contexts during graph
647  construction:
648   1. graph has _control_flow_context: the current context used to
649      construct new nodes. Changed by ctxt.Enter() and ctxt.Exit()
650   2. op has _control_flow_context: the context to which the op belongs.
651      Set at the time the op is created. Immutable.
652   3. A ControlFlowContext has _outer_context: the context in which this
653      context is created. Set at the time a context is created. Immutable.
654   4. A ControlFlowContext has _context_stack.
655      Pushed and popped by ctxt.Enter() and ctxt.Exit()
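
  A schematic of that usage pattern, as exercised by `cond` later in this
  module (internal API; `build_branch` is an illustrative placeholder):

  ```python
  ctxt = CondContext(pred, pivot, branch=1)
  ctxt.Enter()
  outputs = build_branch()    # ops created here belong to `ctxt`
  ctxt.ExitResult(outputs)    # make the results visible to the outer context
  ctxt.Exit()
  ```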
656  """
657
658  def __init__(self, values_def=None, import_scope=None):
659    self._nested_contexts = []
660    self._outer_context = ops.get_default_graph()._get_control_flow_context()
661    if self._outer_context:
662      self._outer_context._nested_contexts.append(self)  # pylint: disable=protected-access
663    self._context_stack = []
664    if values_def:
665      self._init_values_from_proto(values_def, import_scope=import_scope)
666    else:
667      # The names of tensors that have been already seen in this context.
668      self._values = set()
669      # The keys are the names of tensors referenced by but external to this
670      # context. Each value is the Tensor that should be used by this context to
671      # access the key value (e.g. a switch output guarding a cond input value).
672      self._external_values = {}
673
674  def _init_values_from_proto(self, values_def, import_scope=None):
675    """Initializes values and external_values from `ValuesDef` protocol buffer.
676
677    Args:
678      values_def: `ValuesDef` protocol buffer.
679      import_scope: Optional `string`. Name scope to add.
680    """
681    assert isinstance(values_def, control_flow_pb2.ValuesDef)
682    self._values = set(
683        ops.prepend_name_scope(value, import_scope)
684        for value in values_def.values)
685    g = ops.get_default_graph()
686    self._external_values = {}
687    for k, v in values_def.external_values.items():
688      k = ops.prepend_name_scope(k, import_scope)
689      self._external_values[k] = g.as_graph_element(
690          ops.prepend_name_scope(v, import_scope))
691    op_names = set([
692        op.split(":")[0]
693        for op in self._values - set(self._external_values.keys())
694    ])
695    for op in op_names:
696      # pylint: disable=protected-access
697      g.as_graph_element(op)._set_control_flow_context(self)
698      # pylint: enable=protected-access
699
700  @property
701  def name(self):
702    return self._name
703
704  @property
705  def outer_context(self):
706    """Return the context containing this context."""
707    return self._outer_context
708
709  @property
710  def grad_state(self):
711    raise NotImplementedError("Abstract method")
712
713  @property
714  def back_prop(self):
715    raise NotImplementedError("Abstract method")
716
717  @abc.abstractmethod
718  def to_control_flow_context_def(self, context_def, export_scope=None):
719    """Serializes this into `context_def`.
720
721    Args:
722      context_def: a `ControlFlowContextDef` protocol buffer.
723      export_scope: Optional `string`. Name scope to remove.
724    """
725    raise NotImplementedError("Abstract method")
726
727  def _to_values_def(self, export_scope=None):
728    """Converts the values to a `ValuesDef` protocol buffer.
729
730    Args:
731      export_scope: Optional `string`. Name scope to remove.
732
733    Returns:
734      A `ValuesDef` protocol buffer.
735    """
736    values_def = control_flow_pb2.ValuesDef()
737    values_def.values.extend(
738        [ops.strip_name_scope(v, export_scope) for v in sorted(self._values)])
739    for k, v in self._external_values.items():
740      k = ops.strip_name_scope(k, export_scope)
741      values_def.external_values[k] = ops.strip_name_scope(v.name, export_scope)
742    return values_def
743
744  def AddName(self, name):
745    self._values.add(name)
746
747  # pylint: disable=protected-access
748  def Enter(self):
749    """Enter this control flow context."""
750    graph = ops.get_default_graph()
751    self._context_stack.append(graph._get_control_flow_context())
752    graph._set_control_flow_context(self)
753
754  def Exit(self):
755    """Exit this control flow context."""
756    graph = ops.get_default_graph()
757    last_context = self._context_stack.pop()
758    graph._set_control_flow_context(last_context)
759
760  def EnterGradientColocation(self, op, gradient_uid):
761    """Start building a gradient colocated with an op."""
762    if self._outer_context:
763      self._outer_context.EnterGradientColocation(op, gradient_uid)
764
765  def ExitGradientColocation(self, op, gradient_uid):
766    """Start building a gradient colocated with an op."""
767    if self._outer_context:
768      self._outer_context.ExitGradientColocation(op, gradient_uid)
769
770  def ExitResult(self, result):
771    """Make a list of tensors available in the outer context."""
772    if self._outer_context:
773      def fn(x):
774        self._outer_context.AddName(x.name)
775        return x
776      nest.map_structure(fn, result, expand_composites=True)
777
778  def GetWhileContext(self):
779    """Return the while context containing this context."""
780    if self._outer_context:
781      return self._outer_context.GetWhileContext()
782    return None
783
784  def _RemoveExternalControlEdges(self, op):
785    """Remove any external control dependency on this op."""
786    while_ctxt = self.GetWhileContext()
787    # A control input of `op` is internal if it is in the same while
788    # loop context as the enclosing while loop context of self.
789    if while_ctxt is None:
790      internal_control_inputs = op.control_inputs
791    else:
792      internal_control_inputs = []
793      for x in op.control_inputs:
794        ctxt = util.GetOutputContext(x)
795        if ctxt is not None and ctxt.GetWhileContext() == while_ctxt:
796          internal_control_inputs.append(x)
797    external_control_inputs = []
798    if len(internal_control_inputs) != len(op.control_inputs):
799      external_control_inputs = list(
800          set(op.control_inputs) - set(internal_control_inputs))
801      op._remove_all_control_inputs()
802      op._add_control_inputs(internal_control_inputs)
803    return internal_control_inputs, external_control_inputs
804
805  # pylint: enable=protected-access
806
807  def AddInnerOp(self, op):
808    """Notifies a scope about an operator added to an inner scope."""
809    if self._outer_context:
810      self._outer_context.AddInnerOp(op)
811
812  def GetControlPivot(self):
813    """Returns the pivot node for this context, or None."""
814    return None
815
816  def IsWhileContext(self):
817    return False
818
819  def IsCondContext(self):
820    return False
821
822  def IsXLAContext(self):
823    return False
824
825  def __str__(self):
826    return self.name
827
828
829class CondContext(ControlFlowContext):
830  """The context for the conditional construct."""
831
832  def __init__(self,
833               pred=None,
834               pivot=None,
835               branch=None,
836               name="cond_text",
837               context_def=None,
838               import_scope=None):
839    """Creates a `CondContext`.
840
841    Args:
842      pred: The `boolean` tensor for the conditional predicate.
843      pivot: The predicate tensor in this branch.
844      branch: 0 or 1 representing this branch.
845      name: Name of the `CondContext` python object.
846      context_def: Optional `ContextDef` protocol buffer to initialize the
847        `CondContext` object from.
848      import_scope: Optional `string`. Name scope to add. Only used when
849        initializing from a protocol buffer.
850    """
851    self._name = ops.get_default_graph().unique_name(name)
852
853    if context_def:
854      self._init_from_proto(context_def, import_scope=import_scope)
855    else:
856      # Initializes the default fields.
857      ControlFlowContext.__init__(self)
858      self._pred = pred  # The boolean tensor for the cond predicate
859      self._pivot = pivot  # The predicate tensor in this branch
860      self._branch = branch  # 0 or 1 representing this branch
861
862      # Values considered to have been already seen in this context. `pred` is
863      # treated as an external value rather than one produced in this context.
864      self._values.add(pred.name)
865      self._external_values[pred.name] = pred
866      self._values.add(pivot.name)
867      pivot.op._set_control_flow_context(self)  # pylint: disable=protected-access
868
869  def _init_from_proto(self, context_def, import_scope=None):
870    """Creates a new `CondContext` from protocol buffer.
871
872    Args:
873      context_def: `CondContextDef` protocol buffer.
874      import_scope: Optional `string`. Name scope to add.
875    """
876    assert isinstance(context_def, control_flow_pb2.CondContextDef)
877    # Create from context_def.
878    g = ops.get_default_graph()
879    self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
880    self._pred = g.as_graph_element(
881        ops.prepend_name_scope(context_def.pred_name, import_scope))
882    self._pivot = g.as_graph_element(
883        ops.prepend_name_scope(context_def.pivot_name, import_scope))
884    self._branch = context_def.branch
885    super(CondContext, self).__init__(
886        values_def=context_def.values_def, import_scope=import_scope)
887
888  @property
889  def pred(self):
890    return self._pred
891
892  @property
893  def pivot(self):
894    return self._pivot
895
896  @property
897  def branch(self):
898    return self._branch
899
900  @property
901  def grad_state(self):
902    if self.GetWhileContext():
903      return self.GetWhileContext().grad_state
904    return None
905
906  @property
907  def back_prop(self):
908    if self.GetWhileContext():
909      return self.GetWhileContext().back_prop
910    return False
911
912  def GetControlPivot(self):
913    return self._pivot
914
915  def to_proto(self, export_scope=None):
916    """Converts a `CondContext` to a `CondContextDef` protocol buffer.
917
918    Args:
919      export_scope: Optional `string`. Name scope to remove.
920
921    Returns:
922      A `CondContextDef` protocol buffer.
923    """
924    if (export_scope is None or self.name.startswith(export_scope)):
925      context_def = control_flow_pb2.CondContextDef()
926      context_def.context_name = ops.strip_name_scope(self.name, export_scope)
927      context_def.pred_name = ops.strip_name_scope(self._pred.name,
928                                                   export_scope)
929      context_def.pivot_name = ops.strip_name_scope(self._pivot.name,
930                                                    export_scope)
931      context_def.branch = self._branch
932      context_def.values_def.MergeFrom(
933          super(CondContext, self)._to_values_def(export_scope))
934      for nested in self._nested_contexts:
935        nested_def = context_def.nested_contexts.add()
936        nested.to_control_flow_context_def(nested_def)
937
938      return context_def
939    else:
940      return None
941
942  @staticmethod
943  def from_proto(context_def, import_scope=None):
944    """Returns a `CondContext` object created from `context_def`."""
945    ret = CondContext(context_def=context_def, import_scope=import_scope)
946
947    ret.Enter()
948    for nested_def in context_def.nested_contexts:
949      from_control_flow_context_def(nested_def, import_scope=import_scope)
950    ret.Exit()
951    return ret
952
953  def to_control_flow_context_def(self, context_def, export_scope=None):
954    context_def.cond_ctxt.CopyFrom(self.to_proto(export_scope=export_scope))
955
956  def AddValue(self, val):
957    """Add `val` to the current context and its outer context recursively."""
958    if val.name in self._values:
959      # Use the real value if it comes from outer context. This is needed in
960      # particular for nested conds.
961      result = self._external_values.get(val.name)
962      result = val if result is None else result
963    else:
964      result = val
965      self._values.add(val.name)
966      if self._outer_context:
967        result = self._outer_context.AddValue(val)
968        self._values.add(result.name)
969        self._external_values[result.name] = result
970      with ops.control_dependencies(None):
971        result = _SwitchRefOrTensor(result, self._pred)[self._branch]
972        if self._outer_context:
973          self._outer_context.AddInnerOp(result.op)
974
975      result.op.graph.prevent_fetching(result.op)
976      # pylint: disable=protected-access
977      result.op._set_control_flow_context(self)
978      # pylint: enable=protected-access
979
980      # Mark Switch output as seen by this context and any outer contexts,
981      # just like what we do for normal op outputs in _AddOpInternal() below.
982      ctxt = self
983      while ctxt is not None:
984        # pylint: disable=protected-access
985        ctxt._values.add(result.name)
986        ctxt = ctxt._outer_context
987        # pylint: enable=protected-access
988
989      self._external_values[val.name] = result
990    return result
991
992  def AddOp(self, op):
993    self._AddOpInternal(op)
994
995  def _AddOpInternal(self, op):
996    """Add `op` to the current context."""
997    if not op.inputs:
998      # If we're in a while loop, remove any control inputs from outside the
999      # loop.
1000      self._RemoveExternalControlEdges(op)
1001
1002      if not any(
1003          util.OpInContext(input_op, self) for input_op in op.control_inputs):
1004        # pylint: disable=protected-access
1005        op._add_control_input(self._pivot.op)
1006        # pylint: enable=protected-access
1007    else:
1008      # Make each input to 'op' available in this CondContext. If an input is
1009      # already part of this context there's nothing to do, but if it's
1010      # external, AddValue() will handle adding the appropriate Switch node and
1011      # other bookkeeping.
1012      for index in range(len(op.inputs)):
1013        x = op.inputs[index]
1014        if op.type == "Merge" and x.op.type == "NextIteration":
1015          # Edge case: if we're importing a while loop inside this CondContext,
1016          # AddValue() will not correctly handle the NextIteration inputs to
1017          # Merge node. The problem is that the NextIteration should also be
1018          # part of this context, but if we're importing it won't have been
1019          # processed and added to the context yet, so AddValue() will try to
1020          # add a Switch which results in an invalid graph. Instead, we use the
1021          # NextIteration input as-is here, and it will eventually be added to
1022          # the context via AddOp().
1023          real_x = x
1024        else:
1025          real_x = self.AddValue(x)
1026        if real_x != x:
1027          # pylint: disable=protected-access
1028          op._update_input(index, real_x)
1029          # pylint: enable=protected-access
1030      # Remove any external control dependency on this op.
1031      self._RemoveExternalControlEdges(op)
1032      # pylint: disable=protected-access
1033      if op.graph._is_function(op.type) or op.type == "SymbolicGradient":
1034        op._add_control_input(self._pivot.op)
1035      # pylint: enable=protected-access
1036
1037    # Mark op's outputs as seen by this context and any outer contexts.
1038    output_names = [x.name for x in op.outputs]
1039    ctxt = self
1040    while ctxt is not None:
1041      # pylint: disable=protected-access
1042      ctxt._values.update(output_names)
1043      ctxt = ctxt._outer_context
1044      # pylint: enable=protected-access
1045
1046    if self._outer_context or not util.IsLoopExit(op):
1047      op.graph.prevent_fetching(op)
1048
1049    if self._outer_context:
1050      self._outer_context.AddInnerOp(op)
1051
1052  def _ProcessOutputTensor(self, val):
1053    """Process an output tensor of a conditional branch."""
1054    real_val = val
1055    if val.name not in self._values:
1056      # Handle the special case of lambda: x
1057      self._values.add(val.name)
1058      if self._outer_context:
1059        real_val = self._outer_context.AddValue(val)
1060        self._values.add(real_val.name)
1061        self._external_values[real_val.name] = real_val
1062      real_val = _SwitchRefOrTensor(real_val, self._pred)[self._branch]
1063      self._external_values[val.name] = real_val
1064    else:
1065      external_val = self._external_values.get(val.name)
1066      if external_val is not None:
1067        real_val = external_val
1068    return real_val
1069
1070  def _BuildCondTensor(self, v):
1071    if isinstance(v, ops.Operation):
1072      # Use pivot as the proxy for this op.
1073      return with_dependencies([v], self._pivot)
1074    else:
1075      v = nest.map_structure(
1076          _convert_tensorarray_to_flow, v, expand_composites=True)
1077      return self._ProcessOutputTensor(ops.convert_to_tensor(v))
1078
1079  def BuildCondBranch(self, fn):
1080    """Add the subgraph defined by fn() to the graph."""
1081    pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
1082    original_result = fn()
1083    post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
1084    if len(post_summaries) > len(pre_summaries):
1085      new_summaries = post_summaries[len(pre_summaries):]
1086      summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
1087      summary_ref[:] = pre_summaries
1088      with ops.control_dependencies(new_summaries):
1089        if original_result is None:
1090          return no_op(), None
1091        elif not isinstance(original_result, ops.Operation):
1092          original_result = nest.map_structure(
1093              array_ops.identity, original_result, expand_composites=True)
1094    if original_result is None:
1095      return None, None
1096
1097    result = nest.map_structure(
1098        self._BuildCondTensor, original_result, expand_composites=True)
1099    if not isinstance(result, (list, _basetuple)):
1100      result = [result]
1101    return original_result, result
1102
1103  def IsCondContext(self):
1104    return True
1105
1106
1107def _UnpackIfSingleton(res):
1108  if isinstance(res, (list, _basetuple)) and len(res) == 1:
1109    return res[0]
1110  else:
1111    return res
1112
1113
1114def _eager_cond_implementation(pred, true_fn, false_fn, strict, name):
1115  """Special cases for `cond` when executing eagerly."""
1116  pred = ops.convert_to_tensor(pred)
1117  pred_constant_value = tensor_util.constant_value(pred)
1118  if pred_constant_value is None:
1119    # Eager tensors from a parallel device may not have a constant
1120    # value. Running the cond op itself would work, but we don't have logic to
1121    # build cond ops without wrapping in a function first.
1122    if (not isinstance(true_fn, def_function.Function)
1123        or not isinstance(false_fn, def_function.Function)):
1124      raise TypeError("When running tf.cond on a parallel device, 'true_fn' "
1125                      "and 'false_fn' must be decorated with `tf.function`.")
1126    @def_function.function
1127    def _parallel_device_cond_wrapper():
1128      return cond_v2.cond_v2(pred, true_fn, false_fn, name)
1129    functions_run_eagerly = def_function.functions_run_eagerly()
1130    if functions_run_eagerly:
1131      # We need to use tf.function to deal with variable creation inside the
1132      # cond, and skipping it because of run_functions_eagerly would just
1133      # crash immediately.
1134      logging.warning(
1135          "It looks like tf.function behavior was disabled, perhaps using "
1136          "tf.config.run_functions_eagerly. Parallelized tf.cond requires "
1137          "tf.function to work. This primitive will override the disable.")
1138    def_function.run_functions_eagerly(False)
1139    try:
1140      return _parallel_device_cond_wrapper()
1141    finally:
1142      if functions_run_eagerly is not None:
1143        def_function.run_functions_eagerly(functions_run_eagerly)
1144  else:
1145    # For conditions which are eager tensors with a constant value (most of
1146    # them), we only call the relevant branch function and execute it eagerly.
1147    with ops.name_scope(name, "cond", [pred]):
1148      if pred_constant_value:
1149        result = true_fn()
1150      else:
1151        result = false_fn()
1152      if not strict:
1153        result = _UnpackIfSingleton(result)
1154      return result
1155
1156
1157# pylint: disable=redefined-outer-name
1158# pylint: disable=g-doc-args
1159@tf_export(v1=["cond"])
1160@dispatch.add_dispatch_support
1161@deprecation.deprecated_args(
1162    None, "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.",
1163    "fn1", "fn2")
1164def cond(pred,
1165         true_fn=None,
1166         false_fn=None,
1167         strict=False,
1168         name=None,
1169         fn1=None,
1170         fn2=None):
1171  """Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
1172
1173  `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
1174  `false_fn` must have the same non-zero number and type of outputs.
1175
1176  **WARNING**: Any Tensors or Operations created outside of `true_fn` and
1177  `false_fn` will be executed regardless of which branch is selected at runtime.
1178
1179  Although this behavior is consistent with the dataflow model of TensorFlow,
1180  it has frequently surprised users who expected a lazier semantics.
1181  Consider the following simple program:
1182
1183  ```python
1184  z = tf.multiply(a, b)
1185  result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
1186  ```
1187
1188  If `x < y`, the `tf.add` operation will be executed and the `tf.square`
1189  operation will not be executed. Since `z` is needed for at least one
1190  branch of the `cond`, the `tf.multiply` operation is always executed,
1191  unconditionally.
1192
1193  Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
1194  call to `cond`, and not at all during `Session.run()`). `cond`
1195  stitches together the graph fragments created during the `true_fn` and
1196  `false_fn` calls with some additional graph nodes to ensure that the right
1197  branch gets executed depending on the value of `pred`.
1198
1199  `tf.cond` supports nested structures as implemented in
1200  `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
1201  same (possibly nested) value structure of lists, tuples, and/or named tuples.
1202  Singleton lists and tuples form the only exceptions to this: when returned by
1203  `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
1204  This behavior is disabled by passing `strict=True`.
1205
1206  Args:
1207    pred: A scalar determining whether to return the result of `true_fn` or
1208      `false_fn`.
1209    true_fn: The callable to be performed if pred is true.
1210    false_fn: The callable to be performed if pred is false.
1211    strict: A boolean that enables/disables 'strict' mode; see above.
1212    name: Optional name prefix for the returned tensors.
1213
1214  Returns:
1215    Tensors returned by the call to either `true_fn` or `false_fn`. If the
1216    callables return a singleton list, the element is extracted from the list.
1217
1218  Raises:
1219    TypeError: if `true_fn` or `false_fn` is not callable.
1220    ValueError: if `true_fn` and `false_fn` do not return the same number of
1221      tensors, or return tensors of different types.
1222
1223  Example:
1224
1225  ```python
1226  x = tf.constant(2)
1227  y = tf.constant(5)
1228  def f1(): return tf.multiply(x, 17)
1229  def f2(): return tf.add(y, 23)
1230  r = tf.cond(tf.less(x, y), f1, f2)
1231  # r is set to f1().
1232  # Operations in f2 (e.g., tf.add) are not executed.
1233  ```
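
  A sketch of the `strict` flag described above (continuing the example):

  ```python
  r = tf.cond(tf.less(x, y), lambda: [x], lambda: [y])
  # r is a single Tensor: a singleton list is implicitly unpacked.
  r = tf.cond(tf.less(x, y), lambda: [x], lambda: [y], strict=True)
  # r is a one-element list: unpacking is disabled.
  ```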
1234
1235  """
1236  # We needed to make true_fn/false_fn keyword arguments for
1237  # backwards-compatibility. This check exists so that we can convert back to
1238  # having them be positional arguments.
1239  # TODO(josh11b): Make `true_fn` and `false_fn` positional arguments after
1240  # `fn1` and `fn2` are deleted.
1241  if fn1 is not None:
1242    if true_fn is not None:
1243      raise TypeError(
1244          "cond(): 'true_fn' and 'fn1' may not be set simultaneously.")
1245    true_fn = fn1
1246  elif true_fn is None:
1247    raise TypeError("cond(): 'true_fn' argument required")
1248  if fn2 is not None:
1249    if false_fn is not None:
1250      raise TypeError(
1251          "cond(): 'false_fn' and 'fn2' may not be set simultaneously.")
1252    false_fn = fn2
1253  elif false_fn is None:
1254    raise TypeError("cond(): 'false_fn' argument required")
1255
1256  if not callable(true_fn):
1257    raise TypeError("'true_fn' must be callable.")
1258  if not callable(false_fn):
1259    raise TypeError("'false_fn' must be callable.")
1260
1261  if context.executing_eagerly():
1262    return _eager_cond_implementation(pred, true_fn, false_fn, strict, name)
1263
1264  # Always enable control flow v2 if building a function, regardless of toggle.
1265  if util.EnableControlFlowV2(ops.get_default_graph()):
1266    return cond_v2.cond_v2(pred, true_fn, false_fn, name)
1267
1268  with ops.name_scope(name, "cond", [pred]):
1269    # Add the Switch to the graph.
1270    if isinstance(pred, bool):
1271      raise TypeError("'pred' must not be a Python bool.")
1272    p_2, p_1 = switch(pred, pred)
1273    pivot_1 = array_ops.identity(p_1, name="switch_t")
1274    pivot_2 = array_ops.identity(p_2, name="switch_f")
1275    pred = array_ops.identity(pred, name="pred_id")
1276    # Disable the fetching of tensors that are only on one branch of cond.
1277    for tensor in [p_1, p_2, pivot_1, pivot_2, pred]:
1278      tensor.op.graph.prevent_fetching(tensor.op)
1279
1280    # Build the graph for the true branch in a new context.
1281    context_t = CondContext(pred, pivot_1, branch=1)
1282    try:
1283      context_t.Enter()
1284      orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
1285      if orig_res_t is None:
1286        raise ValueError("'true_fn' must have a return value.")
1287      context_t.ExitResult(res_t)
1288    finally:
1289      context_t.Exit()
1290
1291    # Build the graph for the false branch in a new context.
1292    context_f = CondContext(pred, pivot_2, branch=0)
1293    try:
1294      context_f.Enter()
1295      orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
1296      if orig_res_f is None:
1297        raise ValueError("'false_fn' must have a return value.")
1298      context_f.ExitResult(res_f)
1299    finally:
1300      context_f.Exit()
1301
1302    if not strict:
1303      orig_res_t = _UnpackIfSingleton(orig_res_t)
1304      orig_res_f = _UnpackIfSingleton(orig_res_f)
1305
1306    # Check that the return values of the two branches have the same structure.
1307    try:
1308      nest.assert_same_structure(orig_res_t, orig_res_f, expand_composites=True)
1309    except (TypeError, ValueError):
1310      nest.map_structure(_cast_indexed_slice_indices, orig_res_t, orig_res_f)
1311      nest.map_structure(_cast_indexed_slice_indices, res_t, res_f)
1312      try:
1313        nest.assert_same_structure(orig_res_t, orig_res_f,
1314                                   expand_composites=True)
1315      except TypeError as e:
1316        raise TypeError(
1317            f"Incompatible return types of 'true_fn' and 'false_fn': {e}")
1318      except ValueError as e:
1319        raise ValueError(
1320            f"Incompatible return values of 'true_fn' and 'false_fn': {e}")
1321
1322    # Add the final merge to the graph.
1323    if not res_t:
1324      raise ValueError(
1325          "'true_fn' and 'false_fn' must return at least one result.")
1326
1327    res_t_flat = nest.flatten(res_t, expand_composites=True)
1328    res_f_flat = nest.flatten(res_f, expand_composites=True)
1329
1330    for (x, y) in zip(res_t_flat, res_f_flat):
1331      assert isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor)
1332      if x.dtype.base_dtype != y.dtype.base_dtype:
1333        raise ValueError(
1334            "Outputs of 'true_fn' and 'false_fn' must have the same type(s). "
1335            f"Received {x.dtype.name} from 'true_fn' "
1336            f"and {y.dtype.name} from 'false_fn'.")
1337
1338    merges = [merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)]
1339    merges = _convert_flows_to_tensorarrays(
1340        nest.flatten(orig_res_t, expand_composites=True), merges)
1341
1342    # Only add non-nested conds to the collection. Any nested control flow will
1343    # be encapsulated in the root context.
1344    assert context_t.outer_context == context_f.outer_context
1345    if context_t.outer_context is None:
1346      ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t)
1347      ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f)
1348
1349    merges = nest.pack_sequence_as(
1350        structure=orig_res_t, flat_sequence=merges, expand_composites=True)
1351
1352    # Singleton lists and tuples are automatically unpacked if strict == False.
1353    if not strict:
1354      merges = _UnpackIfSingleton(merges)
1355    return merges
1356
1357
1358def _cast_indexed_slice_indices(a, b):
1359  """Cast IndexedSlice.indices from int32 to int64 where necessary.
1360
1361  If `a` and `b` are both IndexedSlices, and their indices have different
1362  dtypes, then cast both their dtypes to `int64` (modifies `a` and `b`
1363  in-place).  Otherwise, does nothing.
1364
1365  Args:
1366    a: A value, which may be an IndexedSlices.
1367    b: A value, which may be an IndexedSlices.
1368  """
1369  if (isinstance(a, ops.IndexedSlices) and isinstance(b, ops.IndexedSlices)
1370      and a.indices.dtype != b.indices.dtype):
1371    # pylint: disable=protected-access
1372    a._indices = math_ops.cast(a.indices, dtypes.int64)
1373    b._indices = math_ops.cast(b.indices, dtypes.int64)
1374
1375
1376# pylint: enable=g-doc-args
1377# pylint: enable=redefined-outer-name
1378
1379
1380@tf_export("cond", v1=[])
1381@dispatch.add_dispatch_support
1382def cond_for_tf_v2(pred, true_fn=None, false_fn=None, name=None):
1383  """Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
1384
1385  `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
1386  `false_fn` must have the same non-zero number and type of outputs.
1387
1388  **WARNING**: Any Tensors or Operations created outside of `true_fn` and
1389  `false_fn` will be executed regardless of which branch is selected at runtime.
1390
1391  Although this behavior is consistent with the dataflow model of TensorFlow,
1392  it has frequently surprised users who expected lazier semantics.
1393  Consider the following simple program:
1394
1395  ```python
1396  z = tf.multiply(a, b)
1397  result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
1398  ```
1399
1400  If `x < y`, the `tf.add` operation will be executed and the `tf.square`
1401  operation will not be executed. Since `z` is needed for at least one
1402  branch of the `cond`, the `tf.multiply` operation is always executed,
1403  unconditionally.
1404
1405  Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
1406  call to `cond`, and not at all during `Session.run()`). `cond`
1407  stitches together the graph fragments created during the `true_fn` and
1408  `false_fn` calls with some additional graph nodes to ensure that the right
1409  branch gets executed depending on the value of `pred`.
1410
1411  `tf.cond` supports nested structures as implemented in
1412  `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
1413  same (possibly nested) value structure of lists, tuples, and/or named tuples.
1414  Singleton lists and tuples form the only exceptions to this: when returned by
1415  `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
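
  For example, a minimal sketch of matching nested structures (`pred` here
  stands for an arbitrary boolean scalar tensor):

  ```python
  def true_fn():
    return (tf.constant(1), {"x": tf.constant(2.0)})
  def false_fn():
    return (tf.constant(0), {"x": tf.constant(-2.0)})
  result = tf.cond(pred, true_fn, false_fn)
  # `result` mirrors the branch structure: a tuple of (Tensor, dict).
  ```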
1416
1417  Note: It is illegal to "directly" use tensors created inside a cond branch
1418  outside it, e.g. by storing a reference to a branch tensor in the python
1419  state. If you need to use a tensor created in a branch function you should
1420  return it as an output of the branch function and use the output from
1421  `tf.cond` instead.
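
  For instance, a minimal sketch of the recommended pattern (`pred` and the
  constants are illustrative):

  ```python
  def true_fn():
    inner = tf.constant(1) * 2
    return inner  # return the branch tensor as a branch output ...
  def false_fn():
    return tf.constant(0)
  result = tf.cond(pred, true_fn, false_fn)
  # ... and use `result` outside the branches instead of keeping a direct
  # reference to `inner`.
  ```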
1422
1423  Args:
1424    pred: A scalar determining whether to return the result of `true_fn` or
1425      `false_fn`.
1426    true_fn: The callable to be performed if pred is true.
1427    false_fn: The callable to be performed if pred is false.
1428    name: Optional name prefix for the returned tensors.
1429
1430  Returns:
1431    Tensors returned by the call to either `true_fn` or `false_fn`. If the
1432    callables return a singleton list, the element is extracted from the list.
1433
1434  Raises:
1435    TypeError: if `true_fn` or `false_fn` is not callable.
1436    ValueError: if `true_fn` and `false_fn` do not return the same number of
1437      tensors, or return tensors of different types.
1438
1439  Example:
1440
1441  ```python
1442  x = tf.constant(2)
1443  y = tf.constant(5)
1444  def f1(): return tf.multiply(x, 17)
1445  def f2(): return tf.add(y, 23)
1446  r = tf.cond(tf.less(x, y), f1, f2)
1447  # r is set to f1().
1448  # Operations in f2 (e.g., tf.add) are not executed.
1449  ```
1450
1451  """
1452  return cond(pred, true_fn=true_fn, false_fn=false_fn, strict=True, name=name)
1453
1454
1455def _resource_safe_shape(t):
1456  """Returns the shape of t or the variable it points to."""
1457  if t.dtype == dtypes.resource:
1458    while t.op.inputs:
1459      t = t.op.inputs[0]
1460    return tensor_shape.TensorShape(t.op.get_attr("shape"))
1461  return array_ops.shape_internal(t, optimize=False)
1462
1463
1464# TODO(yuanbyu): Consider having a unified notion of context for
1465# not only conditionals and loops but also control dependency and
1466# subgraphs.
1467class WhileContext(ControlFlowContext):
1468  """The context for the loop construct."""
1469
1470  def __init__(self,
1471               maximum_iterations=None,
1472               parallel_iterations=10,
1473               back_prop=True,
1474               swap_memory=False,
1475               name="while_context",
1476               grad_state=None,
1477               context_def=None,
1478               import_scope=None):
1479    """Creates a `WhileContext`.
1480
1481    Args:
1482      maximum_iterations: Optional upper bound on number of loop iterations.
1483      parallel_iterations: The number of iterations allowed to run in parallel.
1484      back_prop: Whether backprop is enabled for this while loop.
1485      swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
1486      name: Optional name prefix for the returned tensors.
1487      grad_state: The gradient loop state.
1488      context_def: Optional `WhileContextDef` protocol buffer to initialize the
1489        `WhileContext` Python object from.
1490      import_scope: Optional `string`. Name scope to add. Only used when
1491        initializing from a protocol buffer.
1492    """
1493    if context_def:
1494      self._init_from_proto(context_def, import_scope=import_scope)
1495    else:
1496      ControlFlowContext.__init__(self)
1497      self._init_from_args(maximum_iterations, parallel_iterations, back_prop,
1498                           swap_memory, name)
1499    # The gradient loop state.
1500    self._grad_state = grad_state
1501
1502  def _init_from_args(self, maximum_iterations, parallel_iterations, back_prop,
1503                      swap_memory, name):
1504    """Creates a new `WhileContext` from arguments.
1505
1506    Args:
1507      maximum_iterations: Optional upper bound on number of loop iterations.
1508      parallel_iterations: The number of iterations allowed to run in parallel.
1509      back_prop: Whether backprop is enabled for this while loop.
1510      swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
1511      name: Optional name prefix for the returned tensors.
1512
1513    Raises:
1514      ValueError: If `parallel_iterations` has invalid value.
1515    """
1516    if not isinstance(parallel_iterations, int) or (parallel_iterations <= 0):
1517      raise ValueError("'parallel_iterations' must be a positive integer: "
1518                       "%s" % parallel_iterations)
1519    self._name = ops.get_default_graph().unique_name(name)
1520    self._maximum_iterations = maximum_iterations
1521    self._parallel_iterations = parallel_iterations
1522    self._back_prop = back_prop
1523    self._swap_memory = swap_memory
1524    # We use this node to control constants created by the pred lambda.
1525    self._pivot_for_pred = None
1526    # We use this node to control constants created by the body lambda.
1527    self._pivot_for_body = None
1528    # The boolean tensor for loop termination condition. Used in code
1529    # generation for gradient computation.
1530    self._pivot = None
1531    # The list of exit tensors for loop variables.
1532    self._loop_exits = []
1533    # The list of enter tensors for loop variables.
1534    self._loop_enters = []
1535    self._graph = ops.get_default_graph()
1536
1537  def _init_from_proto(self, context_def, import_scope=None):
1538    """Creates a new `WhileContext` from protocol buffer.
1539
1540    Args:
1541      context_def: `WhileContextDef` protocol buffer.
1542      import_scope: Optional `string`. Name scope to add.
1543    """
1544    assert isinstance(context_def, control_flow_pb2.WhileContextDef)
1545    # Create from context_def.
1546    g = ops.get_default_graph()
1547    self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
1548    if context_def.maximum_iterations_name:
1549      self._maximum_iterations = g.as_graph_element(
1550          ops.prepend_name_scope(context_def.maximum_iterations_name,
1551                                 import_scope))
1552    else:
1553      self._maximum_iterations = None
1554    self._parallel_iterations = context_def.parallel_iterations
1555    self._back_prop = context_def.back_prop
1556    self._swap_memory = context_def.swap_memory
1557    self._pivot_for_pred = g.as_graph_element(
1558        ops.prepend_name_scope(context_def.pivot_for_pred_name, import_scope))
1559    # We use this node to control constants created by the body lambda.
1560    self._pivot_for_body = g.as_graph_element(
1561        ops.prepend_name_scope(context_def.pivot_for_body_name, import_scope))
1562    # The boolean tensor for loop termination condition. Used in code
1563    # generation for gradient computation.
1564    self._pivot = g.as_graph_element(
1565        ops.prepend_name_scope(context_def.pivot_name, import_scope))
1566    # The list of exit tensors for loop variables.
1567    self._loop_exits = [
1568        g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope))
1569        for exit_name in context_def.loop_exit_names
1570    ]
1571    # The list of enter tensors for loop variables.
1572    self._loop_enters = [
1573        g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope))
1574        for enter_name in context_def.loop_enter_names
1575    ]
1576    super(WhileContext, self).__init__(
1577        values_def=context_def.values_def, import_scope=import_scope)
1578
1579    # import_scope causes self.name to be different from the original serialized
1580    # context's name. Rewrite "frame_name" attrs with the new name.
1581    if import_scope:
1582      for tensor_name in self._values:
1583        op = g.as_graph_element(tensor_name).op
1584        if util.IsLoopEnter(op):
1585          # pylint: disable=protected-access
1586          op._set_attr("frame_name",
1587                       attr_value_pb2.AttrValue(s=compat.as_bytes(self.name)))
1588          # pylint: enable=protected-access
1589    self._graph = ops.get_default_graph()
1590
1591  @property
1592  def maximum_iterations(self):
1593    """The maximum number of iterations that will be executed."""
1594    return self._maximum_iterations
1595
1596  @property
1597  def parallel_iterations(self):
1598    """The number of iterations allowed to run in parallel."""
1599    return self._parallel_iterations
1600
1601  @property
1602  def back_prop(self):
1603    """True iff backprop is enabled for this while loop."""
1604    return self._back_prop
1605
1606  @property
1607  def swap_memory(self):
1608    """True iff GPU-CPU memory swap is enabled for this while loop."""
1609    return self._swap_memory
1610
1611  @property
1612  def pivot(self):
1613    """The boolean tensor representing the loop termination condition."""
1614    return self._pivot
1615
1616  @property
1617  def loop_enters(self):
1618    """The list of enter tensors for loop variables."""
1619    return self._loop_enters
1620
1621  @property
1622  def loop_exits(self):
1623    """The list of exit tensors for loop variables."""
1624    return self._loop_exits
1625
1626  @property
1627  def grad_state(self):
1628    """The gradient loop state."""
1629    return self._grad_state
1630
1631  def to_proto(self, export_scope=None):
1632    """Converts a `WhileContext` to a `WhileContextDef` protocol buffer.
1633
1634    Args:
1635      export_scope: Optional `string`. Name scope to remove.
1636
1637    Returns:
1638      A `WhileContextDef` protocol buffer.
1639    """
1640    if (export_scope is None or self.name.startswith(export_scope)):
1641      context_def = control_flow_pb2.WhileContextDef()
1642      context_def.context_name = ops.strip_name_scope(self.name, export_scope)
1643      context_def.parallel_iterations = self._parallel_iterations
1644      if self._maximum_iterations is not None:
1645        context_def.maximum_iterations_name = ops.strip_name_scope(
1646            self._maximum_iterations.name, export_scope)
1647      context_def.back_prop = self._back_prop
1648      context_def.swap_memory = self._swap_memory
1649      context_def.pivot_for_pred_name = ops.strip_name_scope(
1650          self._pivot_for_pred.name, export_scope)
1651      context_def.pivot_for_body_name = ops.strip_name_scope(
1652          self._pivot_for_body.name, export_scope)
1653      context_def.pivot_name = ops.strip_name_scope(self._pivot.name,
1654                                                    export_scope)
1655      context_def.loop_exit_names.extend([
1656          ops.strip_name_scope(l.name, export_scope) for l in self._loop_exits
1657      ])
1658      context_def.loop_enter_names.extend([
1659          ops.strip_name_scope(l.name, export_scope) for l in self._loop_enters
1660      ])
1661      context_def.values_def.MergeFrom(
1662          super(WhileContext, self)._to_values_def(export_scope=export_scope))
1663      for nested in self._nested_contexts:
1664        nested_def = context_def.nested_contexts.add()
1665        nested.to_control_flow_context_def(nested_def)
1666
1667      return context_def
1668    else:
1669      return None
1670
1671  def to_control_flow_context_def(self, context_def, export_scope=None):
1672    context_def.while_ctxt.CopyFrom(self.to_proto(export_scope=export_scope))
1673
1674  @staticmethod
1675  def from_proto(context_def, import_scope=None):
1676    """Returns a `WhileContext` object created from `context_def`.
1677
1678    Args:
1679      context_def: A `WhileContextDef` protocol buffer.
1680      import_scope: Optional `string`. Name scope to add.
1681
1682    Returns:
1683      A `WhileContext` Python object.
1684    """
1685    ret = WhileContext(context_def=context_def, import_scope=import_scope)
1686    ret.Enter()
1687    for nested_def in context_def.nested_contexts:
1688      from_control_flow_context_def(nested_def, import_scope=import_scope)
1689    ret.Exit()
1690    return ret
1691
1692  def GetWhileContext(self):
1693    return self
1694
1695  def GetControlPivot(self):
1696    if self._pivot_for_body is not None:
1697      return self._pivot_for_body
1698    return self._pivot_for_pred
1699
1700  def AddValue(self, val):
1701    """Add `val` to the current context and its outer context recursively."""
1702    result = val
1703    new_value = val.name not in self._values
1704    # Don't treat ops in this context as new values. Usually all known values
1705    # are in self._values, except when we're importing a while loop inside this
1706    # WhileContext. Since there's a cycle in this case, `val` may be part of the
1707    # imported while loop but not yet processed by this context and added to
1708    # self._values in _AddOpInternal. We only want to process external input
1709    # tensors to the while loop here.
1710    new_value &= val.op._control_flow_context is not self  # pylint: disable=protected-access
1711    if new_value:
1712      self._values.add(val.name)
1713
1714      # If we are in a grad context and val is from its forward context,
1715      # use GetRealValue(), which adds the logic to save the history of
1716      # val in forward.
1717      grad_ctxt = ops.get_default_graph()._get_control_flow_context()
1718      if grad_ctxt:
1719        grad_ctxt = grad_ctxt.GetWhileContext()
1720        if grad_ctxt.grad_state:
1721          forward_ctxt = util.GetWhileContext(val.op)
1722          if util.IsLoopExit(val.op):
1723            forward_ctxt = forward_ctxt.outer_context
1724            if forward_ctxt:
1725              forward_ctxt = forward_ctxt.GetWhileContext()
1726          if forward_ctxt == grad_ctxt.grad_state.forward_context:
1727            real_val = grad_ctxt.grad_state.GetRealValue(val)
1728            self._external_values[val.name] = real_val
1729            return real_val
1730
1731      if self._outer_context is not None:
1732        result = self._outer_context.AddValue(val)
1733      # Create an Enter to make `result` known to this loop context.
1734      with ops.control_dependencies(None):
1735        enter = _Enter(
1736            result,
1737            self._name,
1738            is_constant=True,
1739            parallel_iterations=self._parallel_iterations)
1740        enter.graph.prevent_feeding(enter)
1741        if self._outer_context:
1742          self._outer_context.AddInnerOp(enter.op)
1743      # Fix the control inputs and control flow context of these enter ops.
1744      self._FixControlInputsAndContext([enter])
1745
1746      # Add `enter` in this context.
1747      self._values.add(enter.name)
1748      self._external_values[val.name] = enter
1749      result = enter
1750    else:
1751      actual_val = self._external_values.get(val.name)
1752      if actual_val is not None:
1753        result = actual_val
1754    return result
1755
1756  def AddOp(self, op):
1757    """Add `op` to the current context."""
1758    # For a reduction op, if op is in a grad context and its input is from
1759    # its forward context, moving op to the forward context means we would
1760    # store the tensor after the reduction as opposed to the tensor before
1761    # reduction, and therefore could significantly reduce memory consumption.
1762    # For now, we do this only for a few ops.
1763    #
1764    # If in XLA context, do not move constant ops to forward pass as pushing to
1765    # and popping from a stack removes the constant property of an op and breaks
1766    # XLA compilation, which requires certain inputs to be constant for certain
1767    # ops.
1768    if not util.IsInXLAContext(op) and op.type in {"Shape", "Size", "Rank"}:
1769      grad_ctxt = ops.get_default_graph()._get_control_flow_context()
1770      if grad_ctxt:
1771        grad_ctxt = grad_ctxt.GetWhileContext()
1772        if grad_ctxt.grad_state:
1773          op_input_forward_ctxt = util.GetWhileContext(op.inputs[0].op)
1774          if op_input_forward_ctxt == grad_ctxt.grad_state.forward_context:
1775            op_input_ctxt = op.inputs[0].op._get_control_flow_context()
1776            op._set_control_flow_context(op_input_ctxt)
1777            op_input_ctxt._AddOpInternal(op)
1778            return
1779    self._AddOpInternal(op)
1780
1781  def _AddOpInternal(self, op):
1782    """Add `op` to the current context.
1783
1784    We move any external control dependencies of the op to the loop pivot, to
1785    ensure they get executed.
1786    """
1787    # This is needed to prevent frame mismatch errors where there are Const
1788    # nodes inside tf.function in v1 while_loop and inlining is turned on.
1789    if op.type in ["PartitionedCall", "StatefulPartitionedCall"]:
1790      op._add_control_input(self.GetControlPivot().op)  # pylint: disable=protected-access
1791    if not op.inputs:
1792      # Remove any external control dependency on this op
1793      control_inputs, external_inputs = self._RemoveExternalControlEdges(op)
1794      # Add a control edge from the control pivot to this op.
1795      if not control_inputs:
1796        # pylint: disable=protected-access
1797        op._add_control_input(self.GetControlPivot().op)
1798        # pylint: enable=protected-access
1799      for x in op.outputs:
1800        self._values.add(x.name)
1801    else:
1802      for index in range(len(op.inputs)):
1803        x = op.inputs[index]
1804        real_x = self.AddValue(x)
1805        if real_x != x:
1806          op._update_input(index, real_x)  # pylint: disable=protected-access
1807      # Remove any external control dependency on this op.
1808      _, external_inputs = self._RemoveExternalControlEdges(op)
1809      # Add a control dependency to prevent loop invariants from
1810      # enabling ops that should not be executed.
1811      self._MaybeAddControlDependency(op)
1812      for x in op.outputs:
1813        self._values.add(x.name)
1814    if external_inputs:
1815      # Use an identity to pull control inputs as data inputs. Note that we
1816      # ignore ops which don't have outputs. TODO(apassos): fix that
1817      with ops.control_dependencies(None):
1818        self.Enter()
1819        external_inputs = [
1820            array_ops.identity(x.outputs[0]).op
1821            for x in external_inputs
1822            if x.outputs
1823        ]
1824        self.Exit()
1825      op._add_control_inputs(external_inputs)  # pylint: disable=protected-access
1826    if self._outer_context or not util.IsLoopExit(op):
1827      op.graph.prevent_fetching(op)
1828      for x in op.outputs:
1829        op.graph.prevent_feeding(x)
1830
1831    if self._outer_context:
1832      self._outer_context.AddInnerOp(op)
1833
1834  def _MaybeAddControlDependency(self, op):
1835    """Add a control input to the op if it only depends on loop invariants."""
1836
1837    def _IsOpFree(op):
1838      """Determines if `op` needs a control dependency."""
1839      if op.control_inputs:
1840        return False
1841      # pylint: disable=protected-access
1842      if op.graph._is_function(op.type) or op.type == "SymbolicGradient":
1843        return True
1844      # pylint: enable=protected-access
1845      for x in op.inputs:
1846        if not util.IsLoopConstantEnter(x.op):
1847          return False
1848      return True
1849
1850    if _IsOpFree(op):
1851      # pylint: disable=protected-access
1852      op._add_control_input(self.GetControlPivot().op)
1853      # pylint: enable=protected-access
1854
1855  def AddForwardLoopCounter(self, outer_grad_state):
1856    """Adds a loop that counts the number of iterations.
1857
1858    This is added to the forward loop at the time when we start to
1859    create the loop for backprop gradient computation. Called in
1860    the outer context of this forward context.
1861
1862    The pseudocode is:
1863      `n = 0; while (_pivot) { n++; }`
1864
1865    Note that a control dependency is added to `n` to ensure the correct
1866    execution order of stack push ops.
1867
1868    Args:
1869      outer_grad_state: The outer grad state. None if not nested.
1870
1871    Returns:
1872      The number of iterations taken by the forward loop and the loop index.
1873    """
1874    n = constant_op.constant(0, name="f_count")
1875    if outer_grad_state is not None:
1876      # Force the stack pushes of i-th execution of an inner loop to be ordered
1877      # before the pushes of (i+1)-th execution of the same inner loop.
1878      outer_add_op = outer_grad_state.forward_index.op.inputs[0].op
1879      n.op._add_control_input(outer_add_op)  # pylint: disable=protected-access
1880
1881    self.Enter()
1882    self.AddName(n.name)
1883    enter_n = _Enter(
1884        n,
1885        self._name,
1886        is_constant=False,
1887        parallel_iterations=self._parallel_iterations,
1888        name="f_count")
1889    self.loop_enters.append(enter_n)
1890
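    # Build the counter with the usual Merge/Switch/NextIteration pattern:
    # the Merge node is created with `enter_n` wired to both of its inputs as
    # a placeholder, and its second input is rewired below to the
    # NextIteration back edge.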
1891    merge_n = merge([enter_n, enter_n])[0]
1892    switch_n = switch(merge_n, self._pivot)
1893
1894    index = math_ops.add(switch_n[1], 1)
1895    next_n = _NextIteration(index)
1896    merge_n.op._update_input(1, next_n)
1897
1898    total_iterations = exit(switch_n[0], name="f_count")
1899    self.loop_exits.append(total_iterations)
1900    self.ExitResult([total_iterations])
1901    self.Exit()
1902    return total_iterations, next_n
1903
1904  def AddBackpropLoopCounter(self, count, outer_grad_state):
1905    """Add the backprop loop that controls the iterations.
1906
1907    This is added to the backprop loop. It is used to control the loop
1908    termination of the backprop loop. Called in the outer context of
1909    this grad context.
1910
1911    The pseudocode is:
1912      `n = count; while (n >= 1) { n--; }`
1913
1914    Note that a control dependency is added to `final_zero` to ensure the
1915    correct execution order of stack pop ops.
1916
1917    Args:
1918      count: The number of iterations for backprop.
1919      outer_grad_state: The outer grad state. None if not nested.
1920
1921    Returns:
1922      The loop index.
1923    """
1924    in_separate_functions = count.graph is not ops.get_default_graph()
1925    if in_separate_functions:
1926      # Brings the count into this graph
1927      count = array_ops.identity(count)
1928    else:
1929      # TODO(apassos) XLA expects this constant to be created outside the loop,
1930      # so doing that for now.
1931      one = constant_op.constant(1, name="b_count")
1932
1933    self.Enter()
1934    self.AddName(count.name)
1935    enter_count = _Enter(
1936        count,
1937        self._name,
1938        is_constant=False,
1939        parallel_iterations=self._parallel_iterations,
1940        name="b_count")
1941    self.loop_enters.append(enter_count)
1942
1943    merge_count = merge([enter_count, enter_count])[0]
1944    self._pivot_for_pred = merge_count
1945
1946    if in_separate_functions:
1947      one = constant_op.constant(1, name="b_count")
1948    pred = math_ops.greater_equal(merge_count, one)
1949    self._pivot = loop_cond(pred, name="b_count")
1950    switch_count = switch(merge_count, self._pivot)
1951
1952    index = math_ops.subtract(switch_count[1], one)
1953    self._pivot_for_body = index
1954    next_count = _NextIteration(index)
1955    merge_count.op._update_input(1, next_count)
1956
1957    final_zero = exit(switch_count[0], name="b_count")
1958    self.loop_exits.append(final_zero)
1959    if outer_grad_state is not None:
1960      # Force the stack pops of i-th execution of an inner loop to be ordered
1961      # before the pops of (i+1)-th execution of the same inner loop.
1962      # pylint: disable=protected-access
1963      outer_grad_state.grad_sync._add_control_input(final_zero.op)
1964      # pylint: enable=protected-access
1965
1966    self.ExitResult([final_zero])
1967    self.Exit()
1968    return next_count
1969
1970  def AddBackpropAccumulator(self, op, grad):
1971    """Add an accumulation loop for every loop invariant.
1972
1973    This is added to the backprop loop. It is used to accumulate partial
1974    gradients within each loop iteration. Called when in the gradient while
1975    context.
1976
1977    The pseudocode is:
1978      ```
1979      acc = 0.0;
1980      while (_pivot) {
1981        acc += grad;
1982      }
1983      ```
1984
1985    Args:
1986      op: The Enter op for a loop invariant.
1987      grad: The partial gradient of an iteration for a loop invariant.
1988
1989    Returns:
1990      The gradient for a loop invariant.
1991    """
1992    self.Exit()
1993    # Create a zeros tensor with the right shape for acc. If we don't
1994    # know the full shape statically, we will have to get the shape
1995    # dynamically from the forward inference. Getting the shape right
1996    # for the zeros is only needed for the base case when the loop exits
1997    # without running any iterations.
1998    shape = grad.get_shape()
1999    if shape.is_fully_defined():
2000      if self.outer_context:
2001        self.outer_context.Enter()
2002      acc = constant_op.constant(0, grad.dtype, shape=shape, name="b_acc")
2003      if self.outer_context:
2004        self.outer_context.Exit()
2005    else:
2006      value = op.inputs[0]
2007      if (isinstance(self.outer_context, WhileContext) and
2008          self.outer_context.grad_state is not None):
2009        # We are in a nested while loop.
2010        forward_ctxt = self.grad_state.forward_context
2011        forward_ctxt.outer_context.Enter()
2012        zeros_shape = array_ops.shape_internal(value, optimize=False)
2013        forward_ctxt.outer_context.Exit()
2014        outer_grad_state = self.grad_state.outer_grad_state
2015        history_zeros_shape = outer_grad_state.AddForwardAccumulator(
2016            zeros_shape)
2017        self.outer_context.Enter()
2018        real_shape = outer_grad_state.AddBackpropAccumulatedValue(
2019            history_zeros_shape, zeros_shape)
2020        acc = array_ops.zeros(real_shape, grad.dtype)
2021        self.outer_context.Exit()
2022      else:
2023        if self.outer_context:
2024          self.outer_context.Enter()
2025        zeros_shape = array_ops.shape_internal(value, optimize=False)
2026        acc = array_ops.zeros(zeros_shape, grad.dtype)
2027        if self.outer_context:
2028          self.outer_context.Exit()
2029
2030    self.Enter()
2031    self.AddName(acc.name)
2032    enter_acc = _Enter(
2033        acc,
2034        self._name,
2035        is_constant=False,
2036        parallel_iterations=self._parallel_iterations,
2037        name="b_acc")
2038    self.loop_enters.append(enter_acc)
2039
2040    merge_acc = merge([enter_acc, enter_acc], name="b_acc")[0]
2041    switch_acc_false, switch_acc_true = switch(merge_acc, self._pivot)
2042
2043    add_acc = math_ops.add(switch_acc_true, grad)
2044    next_acc = _NextIteration(add_acc)
2045    merge_acc.op._update_input(1, next_acc)  # pylint: disable=protected-access
2046
2047    result_acc = exit(switch_acc_false, name="b_acc")
2048    self.loop_exits.append(result_acc)
2049    self.ExitResult([result_acc])
2050    return result_acc
2051
2052  def AddBackpropIndexedSlicesAccumulator(self, op, grad):
2053    """This is used for accumulating gradients that are IndexedSlices.
2054
2055    This is essentially the equivalent of AddBackpropAccumulator but optimized
2056    for things like updating embeddings from within a while loop.
2057
2058    Args:
2059      op: The Enter op for a loop invariant.
2060      grad: The partial gradients represented as an IndexedSlices.
2061
2062    Returns:
2063      The accumulated IndexedSlices gradient of the loop invariant.
2064    """
2065    values = grad.values
2066    indices = grad.indices
2067    dense_shape = grad.dense_shape
2068
2069    self.Exit()
2070    if self.outer_context:
2071      self.outer_context.Enter()
2072    if values.get_shape().is_fully_defined():
2073      values_shape = tensor_shape.TensorShape([tensor_shape.Dimension(1)] +
2074                                              values.get_shape().dims[1:])
2075      if self.outer_context:
2076        self.outer_context.Enter()
2077      values_acc = constant_op.constant(
2078          0, values.dtype, shape=values_shape, name="b_acc")
2079      if self.outer_context:
2080        self.outer_context.Exit()
2081    else:
2082      values_shape = _resource_safe_shape(op.inputs[0])[1:]
2083      values_shape = array_ops.concat([[1], values_shape], 0)
2084      values_acc = array_ops.zeros(values_shape, dtype=values.dtype)
2085    indices_acc = constant_op.constant([0], indices.dtype)
2086    shape_acc = None
2087    if dense_shape is not None:
2088      if dense_shape.get_shape().is_fully_defined():
2089        if self.outer_context:
2090          self.outer_context.Enter()
2091        shape_acc = constant_op.constant(
2092            0, dense_shape.dtype, shape=dense_shape.get_shape())
2093        if self.outer_context:
2094          self.outer_context.Exit()
2095      else:
2096        shape_acc = array_ops.zeros_like(
2097            array_ops.shape_internal(
2098                op.inputs[0], optimize=False, out_type=dense_shape.dtype),
2099            optimize=False)
2100
2101    if self.outer_context:
2102      self.outer_context.Exit()
2103
2104    self.Enter()
2105    self.AddName(values_acc.name)
2106    self.AddName(indices_acc.name)
2107    init_acc = [indices_acc, values_acc]
2108    if shape_acc is not None:
2109      self.AddName(shape_acc.name)
2110      init_acc.append(shape_acc)
2111
2112    # Set use_input_shape=False since the accumulator tensors will grow in
2113    # size. If use_input_shape=True, the _update_input call below will result in
2114    # incompatible shapes.
2115    enter_acc = [
2116        _Enter(
2117            x,
2118            self._name,
2119            is_constant=False,
2120            parallel_iterations=self._parallel_iterations,
2121            use_input_shape=False,
2122            name="b_acc") for x in init_acc
2123    ]
2124    # Manually set appropriate partial shapes.
2125    enter_acc[0].set_shape([None])
2126    if values_acc.shape.dims is not None:
2127      enter_acc[1].set_shape([None] + values_acc.shape.as_list()[1:])
2128    self.loop_enters.extend(enter_acc)
2129
2130    merge_acc = [merge([x, x], name="b_acc")[0] for x in enter_acc]
2131    switch_acc = [switch(x, self._pivot) for x in merge_acc]
2132
2133    # The actual accumulation.
2134    acc_indexed_slices = [
2135        array_ops.concat([xa[1], xv], 0)
2136        for xa, xv in zip(switch_acc[:2], [indices, values])
2137    ]
2138    if shape_acc is not None:
2139      # For the shape we just keep the maximum
2140      acc_indexed_slices.append(math_ops.maximum(dense_shape, switch_acc[2][1]))
2141
2142    next_acc = [_NextIteration(x) for x in acc_indexed_slices]
2143    for xm, xn in zip(merge_acc, next_acc):
2144      xm.op._update_input(1, xn)  # pylint: disable=protected-access
2145
2146    exit_acc = [exit(x[0], name="b_acc") for x in switch_acc]
2147    self.loop_exits.extend(exit_acc)
2148
2149    self.ExitResult(exit_acc)
2150    return ops.IndexedSlices(
2151        indices=exit_acc[0],
2152        values=exit_acc[1],
2153        dense_shape=exit_acc[2] if shape_acc is not None else None)
2154
2155  def _InitializeValues(self, values):
2156    """Makes the values known to this context."""
2157    self._values = set()
2158    for x in values:
2159      if isinstance(x, ops.Tensor):
2160        self._values.add(x.name)
2161      else:
2162        raise TypeError("'values' must be a list of Tensors. "
2163                        f"Received: {type(x)}.")
2164
2165  def _BuildLoop(self, pred, body, original_loop_vars, loop_vars,
2166                 shape_invariants):
2167    """Core: Add the loop termination condition and body to the graph."""
2168    flat_loop_vars = nest.flatten(original_loop_vars, expand_composites=True)
2169
2170    # Let the context know the loop variables so the loop variables
2171    # would be added in the outer contexts properly.
2172    self._InitializeValues(loop_vars)
2173    real_vars = loop_vars
2174    if self._outer_context:
2175      real_vars = [self._outer_context.AddValue(x) for x in loop_vars]
2176    with ops.control_dependencies(None):
2177      enter_vars = [
2178          _Enter(
2179              x,
2180              self._name,
2181              is_constant=False,
2182              parallel_iterations=self._parallel_iterations,
2183              use_input_shape=(shape_invariants is None)) for x in real_vars
2184      ]
2185      for x in enter_vars:
2186        x.graph.prevent_feeding(x)
2187        if self._outer_context:
2188          self._outer_context.AddInnerOp(x.op)
2189
2190    # Finds the closest enclosing non-None control pivot.
2191    outer_context = self._outer_context
2192    control_pivot = None
2193    while outer_context is not None and control_pivot is None:
2194      control_pivot = outer_context.GetControlPivot()
2195      # pylint: disable=protected-access
2196      outer_context = outer_context._outer_context
2197      # pylint: enable=protected-access
2198
2199    if control_pivot is not None:
2200      for var in enter_vars:
2201        if util.IsLoopConstantEnter(var.op.inputs[0].op):
2202          # pylint: disable=protected-access
2203          var.op._add_control_input(control_pivot.op)
2204          # pylint: enable=protected-access
2205    _SetShapeInvariants(real_vars, enter_vars, shape_invariants)
2206
2207    # Fix the control inputs and control flow context of these enter ops.
2208    self._FixControlInputsAndContext(enter_vars)
2209    self._InitializeValues(enter_vars)
2210    self._loop_enters = enter_vars
2211
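    # Each loop variable gets a Merge node with the Enter tensor temporarily
    # wired to both inputs; the second input is replaced with the
    # NextIteration back edge later via _AddNextAndBackEdge.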
2212    merge_vars = [merge([x, x])[0] for x in enter_vars]
2213    self._pivot_for_pred = merge_vars[0]
2214
2215    # Build the graph for pred.
2216    merge_vars_with_tensor_arrays = (
2217        _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars))
2218    packed_vars = nest.pack_sequence_as(
2219        structure=original_loop_vars,
2220        flat_sequence=merge_vars_with_tensor_arrays,
2221        expand_composites=True)
2222    c = ops.convert_to_tensor(pred(*packed_vars))
2223    self._pivot = loop_cond(c, name="LoopCond")
2224    switch_vars = [_SwitchRefOrTensor(x, self._pivot) for x in merge_vars]
2225
2226    # Build the graph for body.
2227    vars_for_body = [_Identity(x[1]) for x in switch_vars]
2228    self._pivot_for_body = vars_for_body[0]
2229    # Convert TensorArray flow variables inside the context back into
2230    # their associated TensorArrays for calling the body.
2231    vars_for_body_with_tensor_arrays = (
2232        _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body))
2233    packed_vars_for_body = nest.pack_sequence_as(
2234        structure=original_loop_vars,
2235        flat_sequence=vars_for_body_with_tensor_arrays,
2236        expand_composites=True)
2237    pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
2238    body_result = body(*packed_vars_for_body)
2239    post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
2240    if not nest.is_sequence_or_composite(body_result):
2241      body_result = [body_result]
2242    if len(post_summaries) > len(pre_summaries):
2243      new_summaries = post_summaries[len(pre_summaries):]
2244      summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
2245      summary_ref[:] = pre_summaries
2246      with ops.control_dependencies(new_summaries):
2247
2248        def map_fn(x):
2249          # TODO(apassos) figure out how to trigger with tensor arrays as well
2250          if isinstance(x, tensor_array_ops.TensorArray):
2251            return x
2252          return array_ops.identity(x)
2253
2254        body_result = nest.map_structure(
2255            map_fn, body_result, expand_composites=True)
2256
2257    # Compare the structure types of input and output of body.
2258    # For backwards compatibility, the first layer is forced to a list
2259    # during this comparison, because inputs are typically lists and
2260    # outputs of the body are typically tuples.
2261    nest.assert_same_structure(
2262        list(packed_vars_for_body), list(body_result), expand_composites=True)
2263
2264    # Store body_result to keep track of TensorArrays returned by body
2265    original_body_result = body_result
2266    # Convert TensorArrays returned by body into their flow variables
2267    result = nest.map_structure(
2268        _convert_tensorarray_to_flow,
2269        nest.flatten(body_result, expand_composites=True),
2270        expand_composites=True)
2271    result = ops.convert_n_to_tensor_or_composite(result)
2272
2273    # Add NextIteration and the back edges to complete the loop.
2274    if len(merge_vars) != len(result):
2275      raise ValueError("Number of inputs and outputs of 'body' must match "
2276                       f"'loop_vars'. Got {len(result)} for the number of "
2277                       f"inputs/outputs, and {len(merge_vars)} for 'loop_vars'.")
2278    next_vars = []
2279    for m, v in zip(merge_vars, result):
2280      next_vars.append(_AddNextAndBackEdge(m, v))
2281
2282    # Add the exit ops.
2283    exit_vars = [exit(x[0]) for x in switch_vars]
2284    self._loop_exits = exit_vars
2285
2286    # Exit the loop.
2287    self.ExitResult(exit_vars)
2288
2289    return original_body_result, exit_vars
2290
2291  def BuildLoop(self, pred, body, loop_vars, shape_invariants,
2292                return_same_structure):
2293    """Add the loop termination condition and body to the graph."""
2294
2295    # Keep original_loop_vars to identify which are TensorArrays
2296    original_loop_vars = loop_vars
2297    # Convert TensorArrays to their flow variables
2298    loop_vars = nest.map_structure(
2299        _convert_tensorarray_to_flow,
2300        nest.flatten(loop_vars, expand_composites=False),
2301        expand_composites=True)
2302    loop_vars = ops.convert_n_to_tensor_or_composite(loop_vars)
2303    if shape_invariants is None:
2304      shape_invariants = nest.map_structure(
2305          _get_shape_invariant, loop_vars, expand_composites=False)
2306    loop_vars = nest.flatten(loop_vars, expand_composites=True)
2307    try:
2308      self.Enter()
2309      # _BuildLoop calls _update_input in several places. _mutation_lock()
2310      # ensures a Session.run call cannot occur between creating and mutating
2311      # new ops.
2312      with ops.get_default_graph()._mutation_lock():  # pylint: disable=protected-access
2313        original_body_result, exit_vars = self._BuildLoop(
2314            pred, body, original_loop_vars, loop_vars, shape_invariants)
2315    finally:
2316      self.Exit()
2317
2318    flat_result = nest.flatten(original_body_result, expand_composites=True)
2319    # Convert TensorArray flow variables outside the context back into
2320    # their associated TensorArrays for returning to caller.
2321    exit_vars_with_tensor_arrays = (
2322        _convert_flows_to_tensorarrays(flat_result, exit_vars))
2323    packed_exit_vars = nest.pack_sequence_as(
2324        structure=original_body_result,
2325        flat_sequence=exit_vars_with_tensor_arrays,
2326        expand_composites=True)
2327
2328    if return_same_structure:
2329      return packed_exit_vars
2330    else:
2331      return packed_exit_vars[0] if len(exit_vars) == 1 else packed_exit_vars
2332
2333  def _FixControlInputsAndContext(self, enters):
2334    graph = ops.get_default_graph()
2335    # pylint: disable=protected-access
2336    for e in enters:
2337      if isinstance(e, ops.Tensor):
2338        xs = [e]
2339      else:
2340        raise TypeError("'enters' must be a list of Tensors. "
2341                        f"Received: {type(e)}.")
2342      for x in xs:
2343        inp_op = x.op.inputs[0].op
2344        control_inputs = graph._control_dependencies_for_inputs([inp_op])
2345        outer_control_inputs = []
2346        for op in control_inputs:
2347          # We need to keep control inputs that are in any ancestor
2348          # ControlFlowContext, and within outer WhileContext.
2349          keep_as_control_input = True
2350          op_ctxt = util.GetOutputContext(op)
2351          outer_ctxt = self.outer_context
2352          outer_while_context = (None if outer_ctxt is None else
2353                                 outer_ctxt.GetWhileContext())
2354          while outer_ctxt != op_ctxt:
2355            if outer_ctxt is None or outer_ctxt == outer_while_context:
2356              keep_as_control_input = False
2357              break
2358            outer_ctxt = outer_ctxt.outer_context
2359          if keep_as_control_input:
2360            outer_control_inputs.append(op)
2361        x.op._set_control_flow_context(self)
2362        x.op._add_control_inputs(outer_control_inputs)
2363        graph._record_op_seen_by_control_dependencies(x.op)
2364    # pylint: enable=protected-access
2365
2366  def IsWhileContext(self):
2367    return True
2368
2369
2370# @TODO(b/133606651) Replace "shape_invariants" with "loop_vars_signature".
2371# pylint: disable=redefined-outer-name
2372@tf_export("while_loop", v1=[])
2373@deprecation.deprecated_arg_values(
2374    None,
2375    """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
2376Instead of:
2377results = tf.while_loop(c, b, vars, back_prop=False)
2378Use:
2379results = tf.nest.map_structure(tf.stop_gradient, tf.while_loop(c, b, vars))""",
2380    warn_once=True,
2381    back_prop=False)
2382def while_loop_v2(cond,
2383                  body,
2384                  loop_vars,
2385                  shape_invariants=None,
2386                  parallel_iterations=10,
2387                  back_prop=True,
2388                  swap_memory=False,
2389                  maximum_iterations=None,
2390                  name=None):
2391  """Repeat `body` while the condition `cond` is true.
2392
2393  `cond` is a callable returning a boolean scalar tensor. `body` is a callable
2394  returning a (possibly nested) tuple, namedtuple or list of tensors of the same
2395  arity (length and structure) and types as `loop_vars`. `loop_vars` is a
2396  (possibly nested) tuple, namedtuple or list of tensors that is passed to both
2397  `cond` and `body`. `cond` and `body` both take as many arguments as there are
2398  `loop_vars`.
2399
2400  In addition to regular Tensors or IndexedSlices, the body may accept and
2401  return TensorArray objects.  The flows of the TensorArray objects will
2402  be appropriately forwarded between loops and during gradient calculations.
2403
2404  Note that `while_loop` calls `cond` and `body` *exactly once* (inside the
2405  call to `while_loop`, and not at all during `Session.run()`). `while_loop`
2406  stitches together the graph fragments created during the `cond` and `body`
2407  calls with some additional graph nodes to create the graph flow that
2408  repeats `body` until `cond` returns false.
2409
2410  For correctness, `tf.while_loop()` strictly enforces shape invariants for
2411  the loop variables. A shape invariant is a (possibly partial) shape that
2412  is unchanged across the iterations of the loop. An error will be raised
2413  if the shape of a loop variable after an iteration is determined to be more
2414  general than or incompatible with its shape invariant. For example, a shape
2415  of [11, None] is more general than a shape of [11, 17], and [11, 21] is not
2416  compatible with [11, 17]. By default (if the argument `shape_invariants` is
2417  not specified), it is assumed that the initial shape of each tensor in
2418  `loop_vars` is the same in every iteration. The `shape_invariants` argument
2419  allows the caller to specify a less specific shape invariant for each loop
2420  variable, which is needed if the shape varies between iterations. The
2421  `tf.Tensor.set_shape`
2422  function may also be used in the `body` function to indicate that
2423  the output loop variable has a particular shape. The shape invariants for
2424  SparseTensor and IndexedSlices are treated specially as follows:
2425
2426  a) If a loop variable is a SparseTensor, the shape invariant must be
2427  TensorShape([r]) where r is the rank of the dense tensor represented
2428  by the sparse tensor. It means the shapes of the three tensors of the
2429  SparseTensor are ([None], [None, r], [r]). NOTE: The shape invariant here
2430  is the shape of the SparseTensor.dense_shape property. It must be the shape of
2431  a vector.
2432
2433  b) If a loop variable is an IndexedSlices, the shape invariant must be
2434  a shape invariant of the values tensor of the IndexedSlices. It means
2435  the shapes of the three tensors of the IndexedSlices are (shape, [shape[0]],
2436  [shape.ndims]).
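
  For example, a sketch of case (a): a `SparseTensor` loop variable whose
  first dense dimension grows every iteration (shapes are illustrative):

  ```python
  sp0 = tf.sparse.SparseTensor(indices=[[0, 0]], values=[1.0],
                               dense_shape=[1, 4])
  c = lambda i, sp: i < 5
  b = lambda i, sp: (i + 1, tf.sparse.concat(axis=0, sp_inputs=[sp, sp]))
  _, sp_final = tf.while_loop(
      c, b, [tf.constant(0), sp0],
      shape_invariants=[tf.TensorShape([]), tf.TensorShape([2])])
  # TensorShape([2]) is the shape of sp0.dense_shape: a vector of length 2,
  # i.e. the rank of the dense tensor the SparseTensor represents.
  ```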
2437
2438  `while_loop` implements non-strict semantics, enabling multiple iterations
2439  to run in parallel. The maximum number of parallel iterations can be
2440  controlled by `parallel_iterations`, which gives users some control over
2441  memory consumption and execution order. For correct programs, `while_loop`
2442  should return the same result for any parallel_iterations > 0.
2443
2444  For training, TensorFlow stores the tensors that are produced in the
2445  forward inference and are needed in back propagation. These tensors are a
2446  main source of memory consumption and often cause OOM errors when training
2447  on GPUs. When the flag swap_memory is true, we swap out these tensors from
2448  GPU to CPU. This allows us, for example, to train RNN models with very long
2449  sequences and large batches.
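
  For example, a sketch of enabling memory swapping for a long loop (the body
  and sizes are arbitrary):

  ```python
  w = tf.ones([256, 256])
  def body(i, state):
    return i + 1, tf.tanh(tf.matmul(state, w))

  _, final_state = tf.while_loop(
      lambda i, s: i < 10000, body, [tf.constant(0), tf.zeros([64, 256])],
      swap_memory=True)  # forward tensors needed by backprop may be swapped
  ```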
2450
2451  Args:
2452    cond: A callable that represents the termination condition of the loop.
2453    body: A callable that represents the loop body.
2454    loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array,
2455      `Tensor`, and `TensorArray` objects.
2456    shape_invariants: The shape invariants for the loop variables.
2457    parallel_iterations: The number of iterations allowed to run in parallel. It
2458      must be a positive integer.
2459    back_prop: (optional) Deprecated. False disables support for back
2460      propagation. Prefer using `tf.stop_gradient` instead.
2461    swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
2462    maximum_iterations: Optional maximum number of iterations of the while loop
2463      to run.  If provided, the `cond` output is AND-ed with an additional
2464      condition ensuring the number of iterations executed is no greater than
2465      `maximum_iterations`.
2466    name: Optional name prefix for the returned tensors.
2467
2468  Returns:
2469    The output tensors for the loop variables after the loop. The return value
2470      has the same structure as `loop_vars`.
2471
2472  Raises:
2473    TypeError: if `cond` or `body` is not callable.
2474    ValueError: if `loop_vars` is empty.
2475
2476  Example:
2477
2478  ```python
2479  i = tf.constant(0)
2480  c = lambda i: tf.less(i, 10)
2481  b = lambda i: (tf.add(i, 1), )
2482  r = tf.while_loop(c, b, [i])
2483  ```
2484
2485  Example with nesting and a namedtuple:
2486
2487  ```python
2488  import collections
2489  Pair = collections.namedtuple('Pair', 'j, k')
2490  ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
2491  c = lambda i, p: i < 10
2492  b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k)))
2493  ijk_final = tf.while_loop(c, b, ijk_0)
2494  ```
2495
2496  Example using shape_invariants:
2497
2498  ```python
2499  i0 = tf.constant(0)
2500  m0 = tf.ones([2, 2])
2501  c = lambda i, m: i < 10
2502  b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
2503  tf.while_loop(
2504      c, b, loop_vars=[i0, m0],
2505      shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
2506  ```
2507
2508  Example which demonstrates non-strict semantics: In the following
2509  example, the final value of the counter `i` does not depend on `x`. So
2510  the `while_loop` can increment the counter parallel to updates of `x`.
2511  However, because the loop counter at one loop iteration depends
2512  on the value at the previous iteration, the loop counter itself cannot
2513  be incremented in parallel. Hence if we just want the final value of the
2514  counter (which we print on the line `print(sess.run(i))`), then
2515  `x` will never be incremented, but the counter will be updated on a
2516  single thread. Conversely, if we want the value of the output (which we
2517  print on the line `print(sess.run(out).shape)`), then the counter may be
2518  incremented on its own thread, while `x` can be incremented in
2519  parallel on a separate thread. In the extreme case, it is conceivable
2520  that the thread incrementing the counter runs until completion before
2521  `x` is incremented even a single time. The only thing that can never
2522  happen is that the thread updating `x` can never get ahead of the
2523  happen is the thread updating `x` getting ahead of the
2524  counter thread, because the thread incrementing `x` depends on the value
2525
2526  ```python
2527  import tensorflow as tf
2528
2529  n = 10000
2530  x = tf.constant(list(range(n)))
2531  c = lambda i, x: i < n
2532  b = lambda i, x: (tf.compat.v1.Print(i + 1, [i]), tf.compat.v1.Print(x + 1,
2533  [i], "x:"))
2534  i, out = tf.while_loop(c, b, (0, x))
2535  with tf.compat.v1.Session() as sess:
2536      print(sess.run(i))  # prints [0] ... [9999]
2537
2538      # The following line may increment the counter and x in parallel.
2539      # The counter thread may get ahead of the other thread, but not the
2540      # other way around. So you may see things like
2541      # [9996] x:[9987]
2542      # meaning that the counter thread is on iteration 9996,
2543      # while the other thread is on iteration 9987
2544      print(sess.run(out).shape)
2545  ```
2546
2547  """
2548  return while_loop(
2549      cond=cond,
2550      body=body,
2551      loop_vars=loop_vars,
2552      shape_invariants=shape_invariants,
2553      parallel_iterations=parallel_iterations,
2554      back_prop=back_prop,
2555      swap_memory=swap_memory,
2556      name=name,
2557      maximum_iterations=maximum_iterations,
2558      return_same_structure=True)
2559
2560
2561# pylint: disable=redefined-outer-name
2562@tf_export(v1=["while_loop"])
2563def while_loop(cond,
2564               body,
2565               loop_vars,
2566               shape_invariants=None,
2567               parallel_iterations=10,
2568               back_prop=True,
2569               swap_memory=False,
2570               name=None,
2571               maximum_iterations=None,
2572               return_same_structure=False):
2573  """Repeat `body` while the condition `cond` is true.
2574
2575  `cond` is a callable returning a boolean scalar tensor. `body` is a callable
2576  returning a (possibly nested) tuple, namedtuple or list of tensors of the same
2577  arity (length and structure) and types as `loop_vars`. `loop_vars` is a
2578  (possibly nested) tuple, namedtuple or list of tensors that is passed to both
2579  `cond` and `body`. `cond` and `body` both take as many arguments as there are
2580  `loop_vars`.
2581
2582  In addition to regular Tensors or IndexedSlices, the body may accept and
2583  return TensorArray objects.  The flows of the TensorArray objects will
2584  be appropriately forwarded between loops and during gradient calculations.
2585
2586  Note that `while_loop` calls `cond` and `body` *exactly once* (inside the
2587  call to `while_loop`, and not at all during `Session.run()`). `while_loop`
2588  stitches together the graph fragments created during the `cond` and `body`
2589  calls with some additional graph nodes to create the graph flow that
2590  repeats `body` until `cond` returns false.
2591
2592  For correctness, `tf.while_loop()` strictly enforces shape invariants for
2593  the loop variables. A shape invariant is a (possibly partial) shape that
2594  is unchanged across the iterations of the loop. An error will be raised
2595  if the shape of a loop variable after an iteration is determined to be more
2596  general than or incompatible with its shape invariant. For example, a shape
2597  of [11, None] is more general than a shape of [11, 17], and [11, 21] is not
2598  compatible with [11, 17]. By default (if the argument `shape_invariants` is
2599  not specified), it is assumed that the initial shape of each tensor in
2600  `loop_vars` is the same in every iteration. The `shape_invariants` argument
2601  allows the caller to specify a less specific shape invariant for each loop
2602  variable, which is needed if the shape varies between iterations. The
2603  `tf.Tensor.set_shape`
2604  function may also be used in the `body` function to indicate that
2605  the output loop variable has a particular shape. The shape invariants for
2606  SparseTensor and IndexedSlices are treated specially as follows:
2607
2608  a) If a loop variable is a SparseTensor, the shape invariant must be
2609  TensorShape([r]) where r is the rank of the dense tensor represented
2610  by the sparse tensor. It means the shapes of the three tensors of the
2611  SparseTensor are ([None], [None, r], [r]). NOTE: The shape invariant here
2612  is the shape of the SparseTensor.dense_shape property. It must be the shape of
2613  a vector.
2614
2615  b) If a loop variable is an IndexedSlices, the shape invariant must be
2616  a shape invariant of the values tensor of the IndexedSlices. It means
2617  the shapes of the three tensors of the IndexedSlices are (shape, [shape[0]],
2618  [shape.ndims]).
2619
2620  `while_loop` implements non-strict semantics, enabling multiple iterations
2621  to run in parallel. The maximum number of parallel iterations can be
2622  controlled by `parallel_iterations`, which gives users some control over
2623  memory consumption and execution order. For correct programs, `while_loop`
2624  should return the same result for any parallel_iterations > 0.
2625
2626  For training, TensorFlow stores the tensors that are produced in the
2627  forward inference and are needed in back propagation. These tensors are a
2628  main source of memory consumption and often cause OOM errors when training
2629  on GPUs. When the flag swap_memory is true, we swap out these tensors from
2630  GPU to CPU. This allows us, for example, to train RNN models with very long
2631  sequences and large batches.
2632
2633  Args:
2634    cond: A callable that represents the termination condition of the loop.
2635    body: A callable that represents the loop body.
2636    loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array,
2637      `Tensor`, and `TensorArray` objects.
2638    shape_invariants: The shape invariants for the loop variables.
2639    parallel_iterations: The number of iterations allowed to run in parallel. It
2640      must be a positive integer.
2641    back_prop: Whether backprop is enabled for this while loop.
2642    swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
2643    name: Optional name prefix for the returned tensors.
2644    maximum_iterations: Optional maximum number of iterations of the while loop
2645      to run.  If provided, the `cond` output is AND-ed with an additional
2646      condition ensuring the number of iterations executed is no greater than
2647      `maximum_iterations`.
2648    return_same_structure: If True, output has same structure as `loop_vars`. If
2649      eager execution is enabled, this is ignored (and always treated as True).
2650
2651  Returns:
2652    The output tensors for the loop variables after the loop.
2653     If `return_same_structure` is True, the return value has the same
2654     structure as `loop_vars`.
2655     If `return_same_structure` is False, the return value is a Tensor,
     TensorArray or IndexedSlices if the length of `loop_vars` is 1, or a list
2657     otherwise.
2658
2659  Raises:
2660    TypeError: if `cond` or `body` is not callable.
2661    ValueError: if `loop_vars` is empty.
2662
2663  Example:
2664
2665  ```python
2666  i = tf.constant(0)
2667  c = lambda i: tf.less(i, 10)
2668  b = lambda i: tf.add(i, 1)
2669  r = tf.while_loop(c, b, [i])
2670  ```
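
  A variant of the same loop that also bounds the number of iterations via
  `maximum_iterations` (a small sketch; the bound of 5 is arbitrary):

  ```python
  i = tf.constant(0)
  c = lambda i: tf.less(i, 10)
  b = lambda i: tf.add(i, 1)
  # Stops after 5 iterations even though `c` alone would allow 10.
  r = tf.while_loop(c, b, [i], maximum_iterations=5)
  ```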
2671
2672  Example with nesting and a namedtuple:
2673
2674  ```python
2675  import collections
2676  Pair = collections.namedtuple('Pair', 'j, k')
2677  ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
2678  c = lambda i, p: i < 10
2679  b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k)))
2680  ijk_final = tf.while_loop(c, b, ijk_0)
2681  ```
2682
2683  Example using shape_invariants:
2684
2685  ```python
2686  i0 = tf.constant(0)
2687  m0 = tf.ones([2, 2])
2688  c = lambda i, m: i < 10
2689  b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
2690  tf.while_loop(
2691      c, b, loop_vars=[i0, m0],
2692      shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
2693  ```
2694
2695  Example which demonstrates non-strict semantics: In the following
2696  example, the final value of the counter `i` does not depend on `x`. So
2697  the `while_loop` can increment the counter parallel to updates of `x`.
2698  However, because the loop counter at one loop iteration depends
2699  on the value at the previous iteration, the loop counter itself cannot
2700  be incremented in parallel. Hence if we just want the final value of the
2701  counter (which we print on the line `print(sess.run(i))`), then
2702  `x` will never be incremented, but the counter will be updated on a
2703  single thread. Conversely, if we want the value of the output (which we
2704  print on the line `print(sess.run(out).shape)`), then the counter may be
2705  incremented on its own thread, while `x` can be incremented in
2706  parallel on a separate thread. In the extreme case, it is conceivable
2707  that the thread incrementing the counter runs until completion before
  `x` is incremented even a single time. The only thing that can never
  happen is for the thread updating `x` to get ahead of the counter thread,
  because the thread incrementing `x` depends on the value of the counter.
2712
2713  ```python
2714  import tensorflow as tf
2715
2716  n = 10000
2717  x = tf.constant(list(range(n)))
2718  c = lambda i, x: i < n
2719  b = lambda i, x: (tf.compat.v1.Print(i + 1, [i]), tf.compat.v1.Print(x + 1,
2720  [i], "x:"))
2721  i, out = tf.while_loop(c, b, (0, x))
2722  with tf.compat.v1.Session() as sess:
2723      print(sess.run(i))  # prints [0] ... [9999]
2724
2725      # The following line may increment the counter and x in parallel.
2726      # The counter thread may get ahead of the other thread, but not the
2727      # other way around. So you may see things like
2728      # [9996] x:[9987]
2729      # meaning that the counter thread is on iteration 9996,
2730      # while the other thread is on iteration 9987
2731      print(sess.run(out).shape)
2732  ```
2733
2734  """
2735  if not callable(cond):
2736    raise TypeError("'cond' must be callable.")
2737  if not callable(body):
2738    raise TypeError("'body' must be callable.")
2739  if parallel_iterations < 1:
2740    raise TypeError("'parallel_iterations' must be a positive integer.")
2741
2742  # Always enable control flow v2 if building a function, regardless of toggle.
2743  executing_eagerly = context.executing_eagerly()
2744  if (util.EnableControlFlowV2(ops.get_default_graph()) and
2745      not executing_eagerly):
2746    return while_v2.while_loop(
2747        cond,
2748        body,
2749        loop_vars,
2750        shape_invariants=shape_invariants,
2751        parallel_iterations=parallel_iterations,
2752        maximum_iterations=maximum_iterations,
2753        name=name,
2754        return_same_structure=return_same_structure,
2755        back_prop=back_prop)
2756
2757  with ops.name_scope(name, "while", loop_vars):
2758    if not loop_vars:
2759      raise ValueError("'loop_vars' must be provided.")
2760    try_to_pack = (len(loop_vars) == 1 and not return_same_structure)
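    # When 'maximum_iterations' is given, thread an iteration counter through
    # the loop as an extra loop variable and AND it into the user's condition.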
2761    if maximum_iterations is not None:
2762      maximum_iterations = ops.convert_to_tensor(
2763          maximum_iterations, name="maximum_iterations")
2764      if maximum_iterations.shape.ndims != 0:
2765        raise ValueError(
2766            "'maximum_iterations' must be a scalar. "
2767            f"Received shape: {maximum_iterations.shape}")
2768
2769      if executing_eagerly:
2770        counter = 0
2771        maximum_iterations = int(maximum_iterations.numpy())
2772      else:
2773        counter = constant_op.constant(
2774            0, dtype=maximum_iterations.dtype, name="iteration_counter")
2775      orig_cond = cond
2776      orig_body = body
2777      if try_to_pack:
2778        loop_vars = (counter, loop_vars[0])
2779        cond = lambda i, lv: (  # pylint: disable=g-long-lambda
2780            math_ops.logical_and(i < maximum_iterations, orig_cond(lv)))
2781        body = lambda i, lv: (i + 1, orig_body(lv))
2782      else:
2783        loop_vars = (counter, loop_vars)
2784        cond = lambda i, lv: (  # pylint: disable=g-long-lambda
2785            math_ops.logical_and(i < maximum_iterations, orig_cond(*lv)))
2786        body = lambda i, lv: (i + 1, orig_body(*lv))
2787      try_to_pack = False
2788
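    # In eager mode, run the loop directly as a Python 'while', checking on
    # each iteration that the structure of the loop variables is preserved.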
2789    if executing_eagerly:
2790      packed = False  # whether the body result was packed into a 1-item tuple
2791
2792      loop_var_structure = nest.map_structure(type_spec.type_spec_from_value,
2793                                              list(loop_vars))
2794      while cond(*loop_vars):
2795        loop_vars = body(*loop_vars)
2796        if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
2797          packed = True
2798          loop_vars = (loop_vars,)
2799        nest.assert_same_structure(loop_var_structure, list(loop_vars))
2800
2801      def convert(x):
2802        if isinstance(x, tensor_array_ops.TensorArray):
2803          return x
2804        return ops.convert_to_tensor(x)
2805
2806      loop_vars = nest.map_structure(convert, loop_vars, expand_composites=True)
2807      if maximum_iterations is not None:
2808        return loop_vars[1]
2809      else:
2810        return loop_vars[0] if packed else loop_vars
2811
2812    if shape_invariants is not None:
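      # The iteration counter injected for 'maximum_iterations' is a scalar,
      # so prepend a scalar shape invariant for it.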
2813      if maximum_iterations is not None:
2814        shape_invariants = (tensor_shape.TensorShape([]), shape_invariants)
2815
2816      nest.assert_same_structure(
2817          loop_vars, shape_invariants, expand_composites=False)
2818      shape_invariants = nest.map_structure(
2819          _get_shape_invariant,
2820          loop_vars,
2821          shape_invariants,
2822          expand_composites=False)
2823
2824    loop_context = WhileContext(
2825        maximum_iterations=maximum_iterations,
2826        parallel_iterations=parallel_iterations,
2827        back_prop=back_prop,
2828        swap_memory=swap_memory)
2829    # Only add non-nested loops to the collection. Any nested control flow will
2830    # be encapsulated in the root context.
2831    if loop_context.outer_context is None:
2832      ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context)
2833    result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants,
2834                                    return_same_structure)
2835    if maximum_iterations is not None:
2836      return result[1]
2837    else:
2838      return result
2839
2840
2841# pylint: enable=redefined-outer-name
2842
2843
2844def _AsTensorList(x, p):
2845  """Return x as a list of Tensors or IndexedSlices.
2846
2847  For entries of `x` that are Operations, this returns an Identity of `p`
2848  with a dependency on the operation.
2849
2850  Args:
2851    x: A Tensor/IndexedSlices/Operation or a list or tuple of them.
2852    p: A Tensor to return for entries in `x` that are Operations.
2853
2854  Returns:
2855    A list of Tensors or IndexedSlices.
2856  """
2857  if not isinstance(x, (list, _basetuple)):
2858    x = [x]
2859
2860  l = []
2861  for v in x:
2862    if isinstance(v, ops.Operation):
2863      v = with_dependencies([v], p)
2864    v = ops.convert_to_tensor_or_composite(v)
2865    if isinstance(v, ops.Tensor):
2866      l.append(array_ops.identity(v))
2867    else:
2868      l.append(
2869          ops.IndexedSlices(
2870              array_ops.identity(v.values), array_ops.identity(v.indices)))
2871  return l
2872
2873
2874def _CheckResults(a, b):
2875  assert len(a) == len(b), (
2876      "Values returned by a() and b() must have the same length.")
2877  for x, y in zip(a, b):
2878    assert x.dtype == y.dtype, (
2879        "Values returned by a() [%s] and b() [%s] must have "
2880        "the same type: %s, %s." % (x.name, y.name, x.dtype.name, y.dtype.name))
2881
2882
2883def with_dependencies(dependencies, output_tensor, name=None):
2884  """Produces the content of `output_tensor` only after `dependencies`.
2885
2886  In some cases, a user may want the output of an operation to be
2887  consumed externally only after some other dependencies have run
  first. This function returns `output_tensor`, but only after all
  operations in `dependencies` have run. Note that this constrains only the
  returned tensor; there is no guarantee that `output_tensor` itself will be
  evaluated after any of the `dependencies` have run.
2892
2893  See also `tf.tuple` and `tf.group`.
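
  A minimal graph-mode sketch (the names `v`, `assign`, `x` and `y` are only
  illustrative; this helper is not exported as a public `tf.*` symbol, so it is
  referenced through this module):

  ```python
  import tensorflow.compat.v1 as tf
  from tensorflow.python.ops import control_flow_ops

  tf.disable_eager_execution()
  v = tf.Variable(0)
  assign = tf.assign(v, 1)
  x = tf.constant(7)
  # Fetching `y` forces `assign` to run first; fetching `x` alone would not.
  y = control_flow_ops.with_dependencies([assign], x)
  with tf.Session() as sess:
    sess.run(v.initializer)
    print(sess.run(y))  # 7, and `v` has been set to 1.
  ```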
2894
2895  Args:
2896    dependencies: Iterable of operations to run before this op finishes.
2897    output_tensor: A `Tensor` or `IndexedSlices` that will be returned.
2898    name: (Optional) A name for this operation.
2899
2900  Returns:
2901    Same as `output_tensor`.
2902
2903  Raises:
2904    TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`.
2905  """
2906  if context.executing_eagerly():
2907    return output_tensor
2908  with ops.name_scope(name, "control_dependency",
2909                      list(dependencies) + [output_tensor]) as name:
2910    with ops.colocate_with(output_tensor):
2911      with ops.control_dependencies(dependencies):
2912        output_tensor = ops.convert_to_tensor_or_composite(output_tensor)
2913        if isinstance(output_tensor, ops.Tensor):
2914          return _Identity(output_tensor, name=name)
2915        else:
2916          return ops.IndexedSlices(
2917              _Identity(output_tensor.values, name=name), output_tensor.indices,
2918              output_tensor.dense_shape)
2919
2920
2921def _GroupControlDeps(dev, deps, name=None):
2922  with ops.control_dependencies(deps):
2923    if dev is None:
2924      return no_op(name=name)
2925    else:
2926      with ops.device(dev):
2927        return no_op(name=name)
2928
2929
2930# TODO(touts): Accept "inputs" as a list.
2931@tf_export("group")
2932def group(*inputs, **kwargs):
2933  """Create an op that groups multiple operations.
2934
2935  When this op finishes, all ops in `inputs` have finished. This op has no
2936  output.
2937
2938  Note: *In TensorFlow 2 with eager and/or Autograph, you should not require
2939  this method, as ops execute in the expected order thanks to automatic control
2940  dependencies.* Only use `tf.group` when working with v1
2941  `tf.Graph` code.
2942
2943  When operating in a v1-style graph context, ops are not executed in the same
2944  order as specified in the code; TensorFlow will attempt to execute ops in
2945  parallel or in an order convenient to the result it is computing.  `tf.group`
2946  allows you to request that one or more results finish before execution
2947  continues.
2948
2949  `tf.group` creates a single op (of type `NoOp`), and then adds appropriate
2950  control dependencies.  Thus, `c = tf.group(a, b)` will compute the same graph
2951  as this:
2952
2953      with tf.control_dependencies([a, b]):
2954          c = tf.no_op()
2955
2956  See also `tf.tuple` and
2957  `tf.control_dependencies`.
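
  A short v1-style sketch (the two assign ops are only illustrative):

  ```python
  import tensorflow.compat.v1 as tf

  tf.disable_eager_execution()
  a = tf.Variable(1)
  b = tf.Variable(2)
  # Running `step` runs both updates; `step` itself has no output.
  step = tf.group(tf.assign_add(a, 10), tf.assign_add(b, 20))
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(step)
    print(sess.run([a, b]))  # [11, 22]
  ```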
2958
2959  Args:
2960    *inputs: Zero or more tensors to group.
2961    name: A name for this operation (optional).
2962
2963  Returns:
2964    An Operation that executes all its inputs.
2965
2966  Raises:
2967    ValueError: If an unknown keyword argument is provided.
2968  """
2969  if context.executing_eagerly():
2970    return None
2971  name = kwargs.pop("name", None)
2972  if kwargs:
2973    raise ValueError("Unknown keyword arguments: " + ", ".join(kwargs.keys()))
2974  with ops.name_scope(name, "group_deps", inputs) as name:
2975    # Grouping no inputs means do nothing
2976    if not inputs:
2977      return no_op(name=name)
2978
2979    # Sorts *inputs according to their devices.
2980    ops_on_device = {}  # device -> operations specified on the device.
2981    for inp in nest.flatten(inputs, expand_composites=True):
2982      if not hasattr(inp, "device"):
2983        raise TypeError("'inputs' should be zero or more (nested) Tensors. "
2984                        f"Received '{inp}' with type '{type(inp)}'.")
2985      dev = inp.device
2986      if dev in ops_on_device:
2987        ops_on_device[dev].append(inp)
2988      else:
2989        ops_on_device[dev] = [inp]
2990    if len(ops_on_device) == 1:
2991      # 1-level tree. The root node is the returned NoOp node.
2992      (dev, deps), = ops_on_device.items()
2993      return _GroupControlDeps(dev, deps, name=name)
2994
2995    # 2-level tree. The root node is the returned NoOp node.
2996    # deps contains 1 NoOp node for each device.
2997    deps = []
2998
2999    def device_key(dev):
3000      """A sort key that allows None to be compared to strings."""
3001      return "" if dev is None else dev
3002
3003    for dev in sorted(ops_on_device, key=device_key):
3004      deps.append(_GroupControlDeps(dev, ops_on_device[dev]))
3005
3006    with ops.control_dependencies(deps):
3007      return no_op(name=name)
3008
3009
3010@tf_export("tuple", v1=[])
3011@dispatch.add_dispatch_support
3012def tuple_v2(tensors, control_inputs=None, name=None):
3013  """Groups tensors together.
3014
3015  The returned tensors have the same value as the input tensors, but they
3016  are computed only after all the input tensors have been computed.
3017
3018  Note: *In TensorFlow 2 with eager and/or Autograph, you should not require
3019  this method, as ops execute in the expected order thanks to automatic control
3020  dependencies.* Only use `tf.tuple` when working with v1 `tf.Graph` code.
3021
3022  See also `tf.group` and `tf.control_dependencies`.
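
  A brief sketch of the "join" behaviour inside a `tf.function` (the function
  and variable names are only illustrative):

  ```python
  @tf.function
  def joined():
    x = tf.random.uniform([2, 2])
    y = tf.random.uniform([2, 2])
    # gx and gy carry the same values as x and y, but neither is produced
    # until both x and y have been computed.
    gx, gy = tf.tuple([x, y])
    return gx, gy
  ```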
3023
3024  Args:
3025    tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`.
3026    control_inputs: List of additional ops to finish before returning.
3027    name: (optional) A name to use as a `name_scope` for the operation.
3028
3029  Returns:
3030    Same as `tensors`.
3031
3032  Raises:
3033    ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`.
3034    TypeError: If `control_inputs` is not a list of `Operation` or `Tensor`
3035      objects.
3036
3037  """
3038  return tuple(tensors=tensors, name=name, control_inputs=control_inputs)  # pylint: disable=redefined-builtin
3039
3040
3041@tf_export(v1=["tuple"])
3042@dispatch.add_dispatch_support
3043def tuple(tensors, name=None, control_inputs=None):  # pylint: disable=redefined-builtin
3044  """Group tensors together.
3045
3046  This creates a tuple of tensors with the same values as the `tensors`
3047  argument, except that the value of each tensor is only returned after the
3048  values of all tensors have been computed.
3049
3050  `control_inputs` contains additional ops that have to finish before this op
3051  finishes, but whose outputs are not returned.
3052
3053  This can be used as a "join" mechanism for parallel computations: all the
3054  argument tensors can be computed in parallel, but the values of any tensor
3055  returned by `tuple` are only available after all the parallel computations
3056  are done.
3057
3058  See also `tf.group` and
3059  `tf.control_dependencies`.
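
  A v1-graph sketch that also uses `control_inputs` (the assign op is only
  illustrative):

  ```python
  import tensorflow.compat.v1 as tf

  tf.disable_eager_execution()
  v = tf.Variable(0)
  bump = tf.assign_add(v, 1)
  x, y = tf.constant(1), tf.constant(2)
  # gx and gy become available only once x, y, and `bump` have all run.
  gx, gy = tf.tuple([x, y], control_inputs=[bump])
  with tf.Session() as sess:
    sess.run(v.initializer)
    print(sess.run([gx, gy]))  # [1, 2]; `v` is now 1.
  ```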
3060
3061  Args:
3062    tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`.
3063    name: (optional) A name to use as a `name_scope` for the operation.
3064    control_inputs: List of additional ops to finish before returning.
3065
3066  Returns:
3067    Same as `tensors`.
3068
3069  Raises:
3070    ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`.
3071    TypeError: If `control_inputs` is not a list of `Operation` or `Tensor`
3072      objects.
3073
3074  """
3075  if context.executing_eagerly():
3076    return tensors
3077  with ops.name_scope(name, "tuple", tensors) as name:
3078    tensors = [
3079        t if (isinstance(t, ops.Operation) or tensor_util.is_tf_type(t) or
3080              t is None) else ops.convert_to_tensor(t) for t in tensors
3081    ]
3082    gating_ops = [
3083        t if isinstance(t, ops.Operation) else t.op
3084        for t in tensors
3085        if t is not None
3086    ]
3087    if control_inputs:
3088      for c in control_inputs:
3089        if isinstance(c, ops.Tensor):
3090          c = c.op
3091        elif not isinstance(c, ops.Operation):
3092          raise TypeError(
3093              "'control_inputs' must only contain Operation or Tensor. "
3094              f"Received: {type(c)}")
3095        gating_ops.append(c)
3096    # Note that in order to ensure ordering in the pbtxt, we must take care to
3097    # ensure the order here.
3098    gating_ops = sorted(set(gating_ops), key=lambda op: op._id)  # Uniquify ops.
3099    if not gating_ops:
3100      raise ValueError("'tensors' must have at least one Tensor. "
3101                       f"Received: {tensors}.")
3102    gate = group(*gating_ops)
3103    tpl = []
3104    for t in tensors:
3105      if tensor_util.is_tf_type(t):
3106        tpl.append(with_dependencies([gate], t))
3107      elif isinstance(t, ops.Operation):
3108        with ops.control_dependencies([gate]):
3109          tpl.append(group(t))
3110      else:
3111        tpl.append(None)
3112    return tpl
3113
3114
3115def _assert_at_most_n_true(predicates, n, msg):
3116  """Returns an Assert op that checks that at most n predicates are True.
3117
3118  Args:
3119    predicates: list of bool scalar tensors.
3120    n: maximum number of true predicates allowed.
3121    msg: Error message.
3122  """
3123  preds_c = array_ops.stack(predicates, name="preds_c")
3124  num_true_conditions = math_ops.reduce_sum(
3125      math_ops.cast(preds_c, dtypes.int32), name="num_true_conds")
3126  condition = math_ops.less_equal(num_true_conditions,
3127                                  constant_op.constant(n, name="n_true_conds"))
3128  preds_names = ", ".join(getattr(p, "name", "?") for p in predicates)
3129  error_msg = [
3130      "%s: more than %d conditions (%s) evaluated as True:" %
3131      (msg, n, preds_names), preds_c
3132  ]
3133  return Assert(condition, data=error_msg, summarize=len(predicates))
3134
3135
3136def _case_create_default_action(predicates, actions):
3137  """Creates default action for a list of actions and their predicates.
3138
  It uses the input actions to select an arbitrary one as the default and makes
  sure that the corresponding predicates have valid values.
3141
3142  Args:
3143    predicates: a list of bool scalar tensors
3144    actions: a list of callable objects which return tensors.
3145
3146  Returns:
3147    a callable
3148  """
3149  k = len(predicates) - 1  # could pick any
3150  predicate, action = predicates[k], actions[k]
3151  other_predicates, other_actions = predicates[:k], actions[:k]
3152
3153  def default_action():
3154    others_msg = ("Implementation error: "
3155                  "selected default action #%d was called, but some of other "
3156                  "predicates are True: " % k)
3157    default_msg = ("Input error: "
3158                   "None of conditions evaluated as True:",
3159                   array_ops.stack(predicates, name="preds_c"))
3160    with ops.control_dependencies([
3161        _assert_at_most_n_true(other_predicates, n=0, msg=others_msg),
3162        Assert(predicate, data=default_msg)
3163    ]):
3164      return action()
3165
3166  return default_action, other_predicates, other_actions
3167
3168
3169def _case_verify_and_canonicalize_args(pred_fn_pairs, exclusive, name,
3170                                       allow_python_preds):
3171  """Verifies input arguments for the case function.
3172
3173  Args:
3174    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor, and a
3175      callable which returns a list of tensors.
3176    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
3177    name: A name for the case operation.
3178    allow_python_preds: if true, pred_fn_pairs may contain Python bools in
3179      addition to boolean Tensors
3180
3181  Raises:
3182    TypeError: If `pred_fn_pairs` is not a list/dictionary.
3183    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
3184    TypeError: If `fns[i]` is not callable for any i, or `default` is not
3185               callable.
3186
3187  Returns:
3188    a tuple <list of scalar bool tensors, list of callables>.
3189  """
3190  if not isinstance(pred_fn_pairs, (list, _basetuple, dict)):
3191    raise TypeError("'pred_fn_pairs' must be a list, tuple, or dict. "
3192                    f"Received: {type(pred_fn_pairs)}")
3193
3194  if isinstance(pred_fn_pairs, collections.OrderedDict):
3195    pred_fn_pairs = pred_fn_pairs.items()
3196  elif isinstance(pred_fn_pairs, dict):
3197    if context.executing_eagerly():
3198      # No name to sort on in eager mode. Use dictionary traversal order,
3199      # which is nondeterministic in versions of Python < 3.6
3200      if not exclusive:
3201        raise ValueError("Unordered dictionaries are not supported for the "
3202                         "'pred_fn_pairs' argument when `exclusive=False` and "
3203                         "eager mode is enabled.")
3204      pred_fn_pairs = list(pred_fn_pairs.items())
3205    else:
3206      pred_fn_pairs = sorted(
3207          pred_fn_pairs.items(), key=lambda item: item[0].name)
3208      if not exclusive:
3209        logging.warn(
3210            "%s: An unordered dictionary of predicate/fn pairs was "
3211            "provided, but exclusive=False. The order of conditional "
3212            "tests is deterministic but not guaranteed.", name)
3213  for pred_fn_pair in pred_fn_pairs:
3214    if not isinstance(pred_fn_pair, _basetuple) or len(pred_fn_pair) != 2:
3215      raise TypeError("Each entry in 'pred_fn_pairs' must be a 2-tuple. "
3216                      f"Received {pred_fn_pair}.")
3217    pred, fn = pred_fn_pair
3218
3219    if isinstance(pred, ops.Tensor):
3220      if pred.dtype != dtypes.bool:
3221        raise TypeError("pred must be Tensor of type bool: %s" % pred.name)
3222    elif not allow_python_preds:
3223      raise TypeError("pred must be a Tensor, got: %s" % pred)
3224    elif not isinstance(pred, bool):
3225      raise TypeError("pred must be a Tensor or bool, got: %s" % pred)
3226
3227    if not callable(fn):
3228      raise TypeError("fn for pred %s must be callable." % pred.name)
3229
3230  predicates, actions = zip(*pred_fn_pairs)
3231  return predicates, actions
3232
3233
3234def _case_helper(cond_fn,
3235                 pred_fn_pairs,
3236                 default,
3237                 exclusive,
3238                 name,
3239                 allow_python_preds=False,
3240                 **cond_kwargs):
3241  """Implementation of case that allows for different cond functions.
3242
3243  Args:
3244    cond_fn: method that has signature and semantics of `cond` above.
3245    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor, and a
3246      callable which returns a list of tensors.
3247    default: Optional callable that returns a list of tensors.
3248    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
3249    name: A name for this operation (optional).
3250    allow_python_preds: if true, pred_fn_pairs may contain Python bools in
3251      addition to boolean Tensors
3252    **cond_kwargs: keyword arguments that will be passed to `cond_fn`.
3253
3254  Returns:
3255    The tensors returned by the first pair whose predicate evaluated to True, or
3256    those returned by `default` if none does.
3257
3258  Raises:
3259    TypeError: If `pred_fn_pairs` is not a list/dictionary.
3260    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
3261    TypeError: If `fns[i]` is not callable for any i, or `default` is not
3262               callable.
3263  """
3264  predicates, actions = _case_verify_and_canonicalize_args(
3265      pred_fn_pairs, exclusive, name, allow_python_preds)
3266  with ops.name_scope(name, "case", [predicates]):
3267    if default is None:
3268      default, predicates, actions = _case_create_default_action(
3269          predicates, actions)
3270    fn = default
3271    # To eval conditions in direct order we create nested conditions in reverse:
3272    #   cond_fn(c[0], true_fn=.., false_fn=cond_fn(c[1], ...))
3273    for predicate, action in reversed(list(zip(predicates, actions))):
3274      fn = functools.partial(
3275          cond_fn, predicate, true_fn=action, false_fn=fn, **cond_kwargs)
3276    if exclusive:
3277      with ops.control_dependencies([
3278          _assert_at_most_n_true(
3279              predicates, n=1, msg="Input error: exclusive=True")
3280      ]):
3281        return fn()
3282    else:
3283      return fn()
3284
3285
3286def _indexed_case_verify_and_canonicalize_args(branch_fns, default,
3287                                               branch_index):
3288  """Verifies input arguments for the case function.
3289
3290  Args:
3291    branch_fns: Dict or list of pairs of an `int` and a callable which
3292      returns a list of tensors.
3293    default: Optional callable that returns a list of tensors.
3294    branch_index: Optional int `Tensor`, which selects for the corresponding
3295      pred_fn_pair.
3296
3297  Raises:
3298    TypeError: If `branch_fns` is not a list/dictionary.
3299    TypeError: If `branch_fns` is a list but does not contain 2-tuples or
3300               callables.
3301    TypeError: If `fns[i]` is not callable for any i, or `default` is not
3302               callable.
3303
3304  Returns:
3305    branch_fns: validated list of callables for each branch (default last).
3306  """
3307  if not isinstance(branch_index, ops.Tensor):
3308    raise TypeError("'branch_index' must be a Tensor, got {}".format(
3309        type(branch_index)))
3310  if not branch_index.dtype.is_integer:
3311    raise TypeError("'branch_index' must be an integer Tensor, got {}".format(
3312        branch_index.dtype))
3313
3314  if not branch_fns:
3315    raise ValueError("Must provide at least one item in 'branch_fns'")
3316  if not isinstance(branch_fns, (list, _basetuple, dict)):
3317    raise TypeError("'branch_fns' must be a list, tuple, or dict")
3318
3319  if isinstance(branch_fns, dict):
3320    branch_fns = branch_fns.items()
3321
3322  if all(callable(fn) for fn in branch_fns):
3323    branch_fns = list(enumerate(branch_fns))
3324
3325  for key_fn_pair in branch_fns:
3326    if not isinstance(key_fn_pair, _basetuple) or len(key_fn_pair) != 2:
3327      raise TypeError("Each entry in 'branch_fns' must be a 2-tuple. "
3328                      f"Received {key_fn_pair}.")
3329    key, branch_fn = key_fn_pair
3330
3331    if not isinstance(key, int):
3332      raise TypeError("key must be a Python `int`, got {}".format(type(key)))
3333
3334    if not callable(branch_fn):
3335      raise TypeError("fn for key {} must be callable.".format(key))
3336
3337  keys = [p[0] for p in branch_fns]
3338  if min(keys) < 0 or max(keys) >= len(keys) or len(set(keys)) != len(keys):
3339    raise ValueError(
3340        "branch indices (keys) must form contiguous range of [0 to {}) but "
3341        "found {{{}}}".format(len(keys), ",".join(map(str, sorted(keys)))))
3342  actions = [p[1] for p in sorted(branch_fns)]
3343  if default is not None:
3344    actions.append(default)
3345  return actions
3346
3347
3348def _indexed_case_helper(branch_fns,
3349                         default,
3350                         branch_index,
3351                         name,
3352                         lower_using_switch_merge=None):
3353  """Implementation of case that emits the n-way indexed Case op.
3354
3355  Args:
    branch_fns: Dict or list of pairs of an `int` and a callable which
      returns a list of tensors.
3358    default: Optional callable that returns a list of tensors.
3359    branch_index: Optional int `Tensor`, which selects for the corresponding
3360      pred_fn_pair.
3361    name: A name for this operation (optional).
3362    lower_using_switch_merge: Lower this op using switch merge ops (optional).
3363
3364  Returns:
3365    The tensors returned by the pair whose key matched branch_index, or
3366    those returned by `default` if none does.
3367
3368  Raises:
3369    TypeError: If `branch_fns` is not a list/dictionary.
3370    TypeError: If `branch_fns` is a list but does not contain 2-tuples or
3371               callables.
3372    TypeError: If `fns[i]` is not callable for any i, or `default` is not
3373               callable.
3374  """
3375  branch_fns = _indexed_case_verify_and_canonicalize_args(
3376      branch_fns, default, branch_index)
3377  with ops.name_scope(name, "case", [branch_index]):
3378    if context.executing_eagerly() and not hasattr(branch_index, "graph"):
3379      branch_index = array_ops.where(
3380          math_ops.less(branch_index, 0)
3381          | math_ops.greater_equal(branch_index, len(branch_fns)),
3382          len(branch_fns) - 1, branch_index)
3383      return branch_fns[int(branch_index)]()
3384    return cond_v2.indexed_case(
3385        branch_index,
3386        branch_fns,
3387        lower_using_switch_merge=lower_using_switch_merge)
3388
3389
3390@tf_export("case", v1=[])
3391@dispatch.add_dispatch_support
3392def case_v2(pred_fn_pairs,
3393            default=None,
3394            exclusive=False,
3395            strict=False,
3396            name="case"):
3397  """Create a case operation.
3398
3399  See also `tf.switch_case`.
3400
  The `pred_fn_pairs` parameter is a list of N pairs.
3402  Each pair contains a boolean scalar tensor and a python callable that
3403  creates the tensors to be returned if the boolean evaluates to True.
3404  `default` is a callable generating a list of tensors. All the callables
3405  in `pred_fn_pairs` as well as `default` (if provided) should return the same
3406  number and types of tensors.
3407
3408  If `exclusive==True`, all predicates are evaluated, and an exception is
3409  thrown if more than one of the predicates evaluates to `True`.
3410  If `exclusive==False`, execution stops at the first predicate which
3411  evaluates to True, and the tensors generated by the corresponding function
3412  are returned immediately. If none of the predicates evaluate to True, this
3413  operation returns the tensors generated by `default`.
3414
3415  `tf.case` supports nested structures as implemented in
3416  `tf.nest`. All of the callables must return the same (possibly nested) value
3417  structure of lists, tuples, and/or named tuples. Singleton lists and tuples
3418  form the only exceptions to this: when returned by a callable, they are
3419  implicitly unpacked to single values. This behavior is disabled by passing
3420  `strict=True`.
3421
3422  @compatibility(v2)
  `pred_fn_pairs` could be a dictionary in v1. However, tf.Tensor and
  tf.Variable are no longer hashable in v2, so they cannot be used as keys for
  a dictionary. Please use a list or a tuple instead.
3426  @end_compatibility
3427
3428
3429  **Example 1:**
3430
3431  Pseudocode:
3432
3433  ```
3434  if (x < y) return 17;
3435  else return 23;
3436  ```
3437
3438  Expressions:
3439
3440  ```python
3441  f1 = lambda: tf.constant(17)
3442  f2 = lambda: tf.constant(23)
3443  r = tf.case([(tf.less(x, y), f1)], default=f2)
3444  ```
3445
3446  **Example 2:**
3447
3448  Pseudocode:
3449
3450  ```
3451  if (x < y && x > z) raise OpError("Only one predicate may evaluate to True");
3452  if (x < y) return 17;
3453  else if (x > z) return 23;
3454  else return -1;
3455  ```
3456
3457  Expressions:
3458
3459  ```python
3460  def f1(): return tf.constant(17)
3461  def f2(): return tf.constant(23)
3462  def f3(): return tf.constant(-1)
3463  r = tf.case([(tf.less(x, y), f1), (tf.greater(x, z), f2)],
3464           default=f3, exclusive=True)
3465  ```
3466
3467  Args:
3468    pred_fn_pairs: List of pairs of a boolean scalar tensor and a callable which
3469      returns a list of tensors.
3470    default: Optional callable that returns a list of tensors.
3471    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
3472    strict: A boolean that enables/disables 'strict' mode; see above.
3473    name: A name for this operation (optional).
3474
3475  Returns:
3476    The tensors returned by the first pair whose predicate evaluated to True, or
3477    those returned by `default` if none does.
3478
3479  Raises:
3480    TypeError: If `pred_fn_pairs` is not a list/tuple.
3481    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
3482    TypeError: If `fns[i]` is not callable for any i, or `default` is not
3483               callable.
3484  """
3485  return _case_helper(
3486      cond,
3487      pred_fn_pairs,
3488      default,
3489      exclusive,
3490      name,
3491      allow_python_preds=False,
3492      strict=strict)
3493
3494
3495@tf_export(v1=["case"])
3496@dispatch.add_dispatch_support
3497def case(pred_fn_pairs,
3498         default=None,
3499         exclusive=False,
3500         strict=False,
3501         name="case"):
3502  """Create a case operation.
3503
3504  See also `tf.switch_case`.
3505
  The `pred_fn_pairs` parameter is a dict or a list of N pairs.
3507  Each pair contains a boolean scalar tensor and a python callable that
3508  creates the tensors to be returned if the boolean evaluates to True.
3509  `default` is a callable generating a list of tensors. All the callables
3510  in `pred_fn_pairs` as well as `default` (if provided) should return the same
3511  number and types of tensors.
3512
3513  If `exclusive==True`, all predicates are evaluated, and an exception is
3514  thrown if more than one of the predicates evaluates to `True`.
3515  If `exclusive==False`, execution stops at the first predicate which
3516  evaluates to True, and the tensors generated by the corresponding function
3517  are returned immediately. If none of the predicates evaluate to True, this
3518  operation returns the tensors generated by `default`.
3519
3520  `tf.case` supports nested structures as implemented in
3521  `tf.nest`. All of the callables must return the same (possibly nested) value
3522  structure of lists, tuples, and/or named tuples. Singleton lists and tuples
3523  form the only exceptions to this: when returned by a callable, they are
3524  implicitly unpacked to single values. This behavior is disabled by passing
3525  `strict=True`.
3526
3527  If an unordered dictionary is used for `pred_fn_pairs`, the order of the
3528  conditional tests is not guaranteed. However, the order is guaranteed to be
3529  deterministic, so that variables created in conditional branches are created
3530  in fixed order across runs.
3531
3532  @compatibility(eager)
3533  Unordered dictionaries are not supported in eager mode when `exclusive=False`.
3534  Use a list of tuples instead.
3535  @end_compatibility
3536
3537
3538  **Example 1:**
3539
3540  Pseudocode:
3541
3542  ```
3543  if (x < y) return 17;
3544  else return 23;
3545  ```
3546
3547  Expressions:
3548
3549  ```python
3550  f1 = lambda: tf.constant(17)
3551  f2 = lambda: tf.constant(23)
3552  r = tf.case([(tf.less(x, y), f1)], default=f2)
3553  ```
3554
3555  **Example 2:**
3556
3557  Pseudocode:
3558
3559  ```
3560  if (x < y && x > z) raise OpError("Only one predicate may evaluate to True");
3561  if (x < y) return 17;
3562  else if (x > z) return 23;
3563  else return -1;
3564  ```
3565
3566  Expressions:
3567
3568  ```python
3569  def f1(): return tf.constant(17)
3570  def f2(): return tf.constant(23)
3571  def f3(): return tf.constant(-1)
3572  r = tf.case({tf.less(x, y): f1, tf.greater(x, z): f2},
3573           default=f3, exclusive=True)
3574  ```
3575
3576  Args:
3577    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a
3578      callable which returns a list of tensors.
3579    default: Optional callable that returns a list of tensors.
3580    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
3581    strict: A boolean that enables/disables 'strict' mode; see above.
3582    name: A name for this operation (optional).
3583
3584  Returns:
3585    The tensors returned by the first pair whose predicate evaluated to True, or
3586    those returned by `default` if none does.
3587
3588  Raises:
3589    TypeError: If `pred_fn_pairs` is not a list/dictionary.
3590    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
3591    TypeError: If `fns[i]` is not callable for any i, or `default` is not
3592               callable.
3593  """
3594  return _case_helper(
3595      cond,
3596      pred_fn_pairs,
3597      default,
3598      exclusive,
3599      name,
3600      allow_python_preds=False,
3601      strict=strict)
3602
3603
3604@tf_export("switch_case")
3605def switch_case(branch_index,
3606                branch_fns,
3607                default=None,
3608                name="switch_case"):
3609  """Create a switch/case operation, i.e. an integer-indexed conditional.
3610
3611  See also `tf.case`.
3612
3613  This op can be substantially more efficient than `tf.case` when exactly one
3614  branch will be selected. `tf.switch_case` is more like a C++ switch/case
3615  statement than `tf.case`, which is more like an if/elif/elif/else chain.
3616
3617  The `branch_fns` parameter is either a dict from `int` to callables, or list
3618  of (`int`, callable) pairs, or simply a list of callables (in which case the
3619  index is implicitly the key). The `branch_index` `Tensor` is used to select an
3620  element in `branch_fns` with matching `int` key, falling back to `default`
3621  if none match, or `max(keys)` if no `default` is provided. The keys must form
3622  a contiguous set from `0` to `len(branch_fns) - 1`.
3623
3624  `tf.switch_case` supports nested structures as implemented in `tf.nest`. All
3625  callables must return the same (possibly nested) value structure of lists,
3626  tuples, and/or named tuples.
3627
3628  **Example:**
3629
3630  Pseudocode:
3631
3632  ```c++
3633  switch (branch_index) {  // c-style switch
3634    case 0: return 17;
3635    case 1: return 31;
3636    default: return -1;
3637  }
3638  ```
3639  or
3640  ```python
3641  branches = {0: lambda: 17, 1: lambda: 31}
3642  branches.get(branch_index, lambda: -1)()
3643  ```
3644
3645  Expressions:
3646
3647  ```python
3648  def f1(): return tf.constant(17)
3649  def f2(): return tf.constant(31)
3650  def f3(): return tf.constant(-1)
3651  r = tf.switch_case(branch_index, branch_fns={0: f1, 1: f2}, default=f3)
3652  # Equivalent: tf.switch_case(branch_index, branch_fns={0: f1, 1: f2, 2: f3})
3653  ```
3654
3655  Args:
3656    branch_index: An int Tensor specifying which of `branch_fns` should be
3657      executed.
3658    branch_fns: A `dict` mapping `int`s to callables, or a `list` of
3659      (`int`, callable) pairs, or simply a list of callables (in which case the
3660      index serves as the key). Each callable must return a matching structure
3661      of tensors.
3662    default: Optional callable that returns a structure of tensors.
3663    name: A name for this operation (optional).
3664
3665  Returns:
3666    The tensors returned by the callable identified by `branch_index`, or those
3667    returned by `default` if no key matches and `default` was provided, or those
3668    returned by the max-keyed `branch_fn` if no `default` is provided.
3669
3670  Raises:
3671    TypeError: If `branch_fns` is not a list/dictionary.
3672    TypeError: If `branch_fns` is a list but does not contain 2-tuples or
3673               callables.
3674    TypeError: If `fns[i]` is not callable for any i, or `default` is not
3675               callable.
3676  """
3677  return _indexed_case_helper(branch_fns, default, branch_index, name)
3678
3679
3680@tf_export("__internal__.execute_fn_for_device", v1=[])
3681def execute_fn_for_device(device_branch_fns, default_fn, name="execute_fn"):
3682  """Executes one of the provided callables based on the device placement.
3683
  This API is used when the implementation of a high-level function depends on
  the underlying device placement. It takes a dictionary mapping device types
  to callables. Device types include "CPU", "GPU", "TPU", etc. When the type of
  the device on which this op runs matches a key in 'device_branch_fns',
  the corresponding callable is executed, falling back to 'default_fn' if none
  matches.
3690
3691  **Example:**
3692  ```python
3693  def f1(): return tf.constant(1)
3694  def f2(): return tf.constant(2)
3695  r = tf.execute_fn_for_device({"CPU": f1, "GPU": f2}, default_fn=f1)
3696  ```
  'r' evaluates to 1 when running on CPU, 2 when running on GPU, and 1 on
  any other device type.
3699
3700
3701  Args:
3702    device_branch_fns: a dictionary of device types to the callables. Each
3703      callable must return a matching structure of tensors.
3704    default_fn: fallback callable when the underlying device does not match any
3705      key in the 'device_branch_fns'.
3706    name: A name for this operation (optional).
3707
3708  Returns:
3709    The tensors returned by the callable identified by device type during
3710    execution, or those returned by 'default_fn' if no key matches.
3711  """
  # For XLA, always execute the default fn to avoid the complicated graph that
  # the Case op would produce. See more discussions in b/167276293.
3714  is_in_xla = util.GraphOrParentsInXlaContext(ops.get_default_graph())
3715  if is_in_xla:
3716    return default_fn()
3717  device_branch_fns_upper = {k.upper(): v for k, v in device_branch_fns.items()}
3718  branch_fns = list(device_branch_fns_upper.values())
3719  devices = list(device_branch_fns_upper.keys())
3720  device_index = gen_functional_ops.device_index(device_names=devices)
3721  return _indexed_case_helper(
3722      branch_fns,
3723      default_fn,
3724      device_index,
3725      name,
3726      lower_using_switch_merge=False)
3727
3728
3729class XLAControlFlowContext(ControlFlowContext):
3730  """Base class for XLA and TPU control flow contexts."""
3731
3732  def __init__(self):
3733    super(XLAControlFlowContext, self).__init__()
3734    self._name = "XLAControlFlowContext"
3735
3736  def to_control_flow_context_def(self, context_def, export_scope=None):
3737    # pylint: disable=useless-super-delegation
3738    # NOTE(slebedev): the method is required by `ControlFlowContext`.
3739    super(XLAControlFlowContext,
3740          self).to_control_flow_context_def(context_def, export_scope)
3741
3742  def IsXLAContext(self):
3743    return True
3744
3745  def AddOp(self, _):
3746    pass
3747
3748  def AddValue(self, x):
3749    return x
3750
3751  def RequiresUniqueFunctionRetracing(self):
3752    """Returns whether the tf.function should be retraced if the context changes.
3753    """
3754    return False
3755
3756
3757@tf_export("__internal__.get_enclosing_xla_context", v1=[])
3758def get_enclosing_xla_context():
3759  """Recursively find and return the XLAControlFlowContext."""
3760  graph = ops.get_default_graph()
3761  while graph is not None:
3762    # pylint: disable=protected-access
3763    context_ = graph._get_control_flow_context()
3764    # pylint: enable=protected-access
3765    while context_ is not None:
3766      if isinstance(context_, XLAControlFlowContext):
3767        return context_
3768      context_ = context_.outer_context
3769    # This may be a FuncGraph due to defuns or v2 control flow. We need to
3770    # find the original graph with the XLAControlFlowContext.
3771    graph = getattr(graph, "outer_graph", None)
3772  return None
3773
3774
3775def from_control_flow_context_def(context_def, import_scope=None):
3776  """Deserializes `context_def` into the appropriate ControlFlowContext.
3777
3778  Args:
3779    context_def: ControlFlowContextDef proto
3780    import_scope: Optional `string`. Name scope to add.
3781
3782  Returns:
3783    A ControlFlowContext subclass
3784  """
3785  if context_def.HasField("cond_ctxt"):
3786    return CondContext.from_proto(
3787        context_def.cond_ctxt, import_scope=import_scope)
3788  if context_def.HasField("while_ctxt"):
3789    return WhileContext.from_proto(
3790        context_def.while_ctxt, import_scope=import_scope)
3791  raise NotImplementedError("Unknown ControlFlowContextDef field: %s" %
3792                            context_def.WhichOneof("ctxt"))
3793
3794
3795ops.register_proto_function(
3796    ops.GraphKeys.COND_CONTEXT,
3797    proto_type=control_flow_pb2.CondContextDef,
3798    to_proto=CondContext.to_proto,
3799    from_proto=CondContext.from_proto)
3800
3801ops.register_proto_function(
3802    ops.GraphKeys.WHILE_CONTEXT,
3803    proto_type=control_flow_pb2.WhileContextDef,
3804    to_proto=WhileContext.to_proto,
3805    from_proto=WhileContext.from_proto)
3806