1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Control Flow Operations.
16
17See the [autograph](https://www.tensorflow.org/guide/autograph) guide.
18"""
19# pylint: disable=g-bad-name
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import abc
25import collections
26import functools
27
28import six
29
30from tensorflow.core.framework import attr_value_pb2
31from tensorflow.core.protobuf import control_flow_pb2
32from tensorflow.python.eager import context
33from tensorflow.python.framework import composite_tensor
34from tensorflow.python.framework import constant_op
35from tensorflow.python.framework import dtypes
36from tensorflow.python.framework import errors
37from tensorflow.python.framework import ops
38from tensorflow.python.framework import tensor_shape
39from tensorflow.python.framework import tensor_spec
40from tensorflow.python.framework import tensor_util
41from tensorflow.python.framework import type_spec
42from tensorflow.python.ops import array_ops
43from tensorflow.python.ops import control_flow_util as util
44from tensorflow.python.ops import gen_array_ops
45from tensorflow.python.ops import gen_control_flow_ops
46from tensorflow.python.ops import gen_logging_ops
47from tensorflow.python.ops import gen_math_ops
48from tensorflow.python.ops import math_ops
49from tensorflow.python.ops import tensor_array_ops
50# go/tf-wildcard-import
51# pylint: disable=wildcard-import,undefined-variable
52from tensorflow.python.ops.gen_control_flow_ops import *
53# pylint: enable=wildcard-import
54from tensorflow.python.platform import tf_logging as logging
55from tensorflow.python.util import compat
56from tensorflow.python.util import deprecation
57from tensorflow.python.util import nest
58from tensorflow.python.util import tf_should_use
59from tensorflow.python.util.lazy_loader import LazyLoader
60from tensorflow.python.util.tf_export import tf_export
61
62# This is to avoid a circular dependency:
63# cond_v2 -> gradients_util -> control_flow_ops
64cond_v2 = LazyLoader("cond_v2", globals(),
65                     "tensorflow.python.ops.cond_v2")
66
67# This is to avoid circular dependencies:
68# while_v2 -> control_flow_ops
69# while_v2 -> gradients_util -> control_flow_ops
70while_v2 = LazyLoader("while_v2", globals(),
71                      "tensorflow.python.ops.while_v2")
72
73# We override the 'tuple' for a control flow op, so we keep python's
74# existing 'tuple' for later use in this module.
75_basetuple = tuple
76
77
78def _summarize_eager(tensor, summarize=None):
79  """Returns a summarized string representation of eager `tensor`.
80
81  Args:
82    tensor: EagerTensor to summarize
83    summarize: Include only this many leading elements of `tensor`
84  """
85  # Emulate the behavior of Tensor::SummarizeValue()
86  if summarize is None:
87    summarize = 3
88  elif summarize < 0:
89    summarize = array_ops.size(tensor)
90
91  # reshape((-1,)) is the fastest way to get a flat array view
92  if tensor._rank():  # pylint: disable=protected-access
93    flat = tensor.numpy().reshape((-1,))
94    lst = [str(x) for x in flat[:summarize]]
95    if len(lst) < flat.size:
96      lst.append("...")
97  else:
98    # tensor.numpy() returns a scalar for zero dimensional arrays
99    if gen_math_ops.not_equal(summarize, 0):
100      lst = [str(tensor.numpy())]
101    else:
102      lst = []
103
104  return ", ".join(lst)
105
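# Editorial sketch (not part of the original file): a minimal illustration of
# what `_summarize_eager` above produces, assuming eager execution is enabled
# so that `constant_op.constant` yields an EagerTensor. The helper name is
# hypothetical and never called by the library.
def _example_summarize_eager():
  t = constant_op.constant([1, 2, 3, 4])
  # Only the first `summarize` elements are shown, followed by "...".
  return _summarize_eager(t, summarize=2)  # -> "1, 2, ..."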
106
107# pylint: disable=protected-access
108
109
110# Assert and Print are special symbols in python, so we must
111# use an upper-case version of them.
112@tf_export("debugging.Assert", "Assert")
113@tf_should_use.should_use_result
114def Assert(condition, data, summarize=None, name=None):
115  """Asserts that the given condition is true.
116
117  If `condition` evaluates to false, print the list of tensors in `data`.
118  `summarize` determines how many entries of the tensors to print.
119
120  NOTE: In graph mode, to ensure that Assert executes, one usually attaches
121  a dependency:
122
123  ```python
124  # Ensure maximum element of x is smaller or equal to 1
125  assert_op = tf.Assert(tf.less_equal(tf.reduce_max(x), 1.), [x])
126  with tf.control_dependencies([assert_op]):
127    ... code using x ...
128  ```
129
130  Args:
131    condition: The condition to evaluate.
132    data: The tensors to print out when condition is false.
133    summarize: Print this many entries of each tensor.
134    name: A name for this operation (optional).
135
136  Returns:
137    assert_op: An `Operation` that, when executed, raises a
138    `tf.errors.InvalidArgumentError` if `condition` is not true.
139    @compatibility(eager)
140    returns None
141    @end_compatibility
142
143  Raises:
144    @compatibility(eager)
145    `tf.errors.InvalidArgumentError` if `condition` is not true
146    @end_compatibility
147  """
148  if context.executing_eagerly():
149    if not condition:
150      xs = ops.convert_n_to_tensor(data)
151      data_str = [_summarize_eager(x, summarize) for x in xs]
152      raise errors.InvalidArgumentError(
153          node_def=None,
154          op=None,
155          message="Expected '%s' to be true. Summarized data: %s" %
156          (condition, "\n".join(data_str)))
157    return
158
159  with ops.name_scope(name, "Assert", [condition, data]) as name:
160    xs = ops.convert_n_to_tensor(data)
161    if all(x.dtype in {dtypes.string, dtypes.int32} for x in xs):
162      # As a simple heuristic, we assume that string and int32 are
163      # on host to avoid the need to use cond. If that is not the case,
164      # we will pay the price of copying the tensor to host memory.
165      return gen_logging_ops._assert(condition, data, summarize, name="Assert")
166    else:
167      condition = ops.convert_to_tensor(condition, name="Condition")
168
169      def true_assert():
170        return gen_logging_ops._assert(
171            condition, data, summarize, name="Assert")
172
173      guarded_assert = cond(condition, no_op, true_assert, name="AssertGuard")
174      if context.executing_eagerly():
175        return
176      return guarded_assert.op
177
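# Editorial sketch (not part of the original file): the eager path of `Assert`
# above checks the Python truth value of `condition` immediately and raises
# `errors.InvalidArgumentError` with a summary of `data` when it is false.
# Assumes eager execution; the helper name is hypothetical and never called.
def _example_assert_eager():
  x = constant_op.constant([0.5, 0.9])
  try:
    Assert(False, [x], summarize=2)
  except errors.InvalidArgumentError as e:
    # The message includes "Expected 'False' to be true" and a summary of x.
    return str(e)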
178
179def _Identity(data, name=None):
180  """Return a tensor with the same shape and contents as the input tensor.
181
182  Args:
183    data: A Tensor.
184    name: A name for this operation (optional).
185
186  Returns:
187    A Tensor with the same type and value as the input Tensor.
188  """
189  data = ops.internal_convert_to_tensor_or_composite(data, as_ref=True)
190  if isinstance(data, ops.Tensor):
191    if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
192      return gen_array_ops.ref_identity(data, name=name)
193    else:
194      return array_ops.identity(data, name=name)
195  elif isinstance(data, composite_tensor.CompositeTensor):
196    return nest.map_structure(_Identity, data, expand_composites=True)
197  else:
198    raise TypeError("Type %s not supported" % type(data))
199
200
201def _NextIteration(data, name=None):
202  data = ops.internal_convert_to_tensor_or_composite(data, as_ref=True)
203  if isinstance(data, ops.Tensor):
204    if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
205      return ref_next_iteration(data, name=name)
206    else:
207      return next_iteration(data, name=name)
208  elif isinstance(data, composite_tensor.CompositeTensor):
209    return nest.map_structure(_NextIteration, data, expand_composites=True)
210  else:
211    raise TypeError("Type %s not supported" % type(data))
212
213
214def _Enter(data,
215           frame_name,
216           is_constant=False,
217           parallel_iterations=10,
218           use_ref=True,
219           use_input_shape=True,
220           name=None):
221  """Creates or finds a child frame, and makes `data` available to it.
222
223  The unique `frame_name` is used by the `Executor` to identify frames. If
224  `is_constant` is true, `data` is a constant in the child frame; otherwise
225  it may be changed in the child frame. At most `parallel_iterations`
226  iterations are run in parallel in the child frame.
227
228  Args:
229    data: The tensor to be made available to the child frame.
230    frame_name: The name of the child frame.
231    is_constant: If true, the output is constant within the child frame.
232    parallel_iterations: The number of iterations allowed to run in parallel.
233    use_ref: If true, use ref_enter if data is of ref type.
234    use_input_shape: If true, set the result's shape based on data's shape.
235    name: A name for this operation (optional).
236
237  Returns:
238    The same tensor as `data`.
239  """
240  data = ops.internal_convert_to_tensor_or_composite(data, as_ref=True)
241  if isinstance(data, ops.Tensor):
242    if data.dtype._is_ref_dtype and use_ref:  # pylint: disable=protected-access
243      result = gen_control_flow_ops.ref_enter(
244          data, frame_name, is_constant, parallel_iterations, name=name)
245    else:
246      result = gen_control_flow_ops.enter(
247          data, frame_name, is_constant, parallel_iterations, name=name)
248    if use_input_shape:
249      result.set_shape(data.get_shape())
250    return result
251  elif isinstance(data, composite_tensor.CompositeTensor):
252
253    def enter_component(t):
254      return _Enter(t, frame_name, is_constant, parallel_iterations, use_ref,
255                    use_input_shape)
256
257    return nest.map_structure(enter_component, data, expand_composites=True)
258  else:
259    raise TypeError("Type %s not supported" % type(data))
260
261
262def exit(data, name=None):  # pylint: disable=redefined-builtin
263  """Exits the current frame to its parent frame.
264
265  Exit makes its input `data` available to the parent frame.
266
267  Args:
268    data: The tensor to be made available to the parent frame.
269    name: A name for this operation (optional).
270
271  Returns:
272    The same tensor as `data`.
273  """
274  data = ops.internal_convert_to_tensor_or_composite(data, as_ref=True)
275  if isinstance(data, ops.Tensor):
276    if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
277      return gen_control_flow_ops.ref_exit(data, name)
278    else:
279      return gen_control_flow_ops._exit(data, name)
280  elif isinstance(data, composite_tensor.CompositeTensor):
281    return nest.map_structure(exit, data, expand_composites=True)
282  else:
283    raise TypeError("Type %s not supported" % type(data))
284
285
286def switch(data, pred, dtype=None, name=None):
287  """Forwards `data` to an output determined by `pred`.
288
289  If `pred` is false, the `data` input is forwarded to the first output.
290  Otherwise, the data goes to the second output.
291
292  This op handles `Tensor`s and `IndexedSlices`.
293
294  Args:
295    data: The tensor to be forwarded to the appropriate output.
296    pred: A scalar that specifies which output port will receive data.
297    dtype: Optional element type for the returned tensor. If missing, the type
298      is inferred from the type of `data`.
299    name: A name for this operation (optional).
300
301  Returns:
302    `(output_false, output_true)`: If `pred` is true, data will be forwarded
303    to `output_true`, otherwise it goes to `output_false`.
304  """
305  with ops.name_scope(name, "Switch", [data, pred]) as name:
306    data = ops.internal_convert_to_tensor_or_composite(
307        data, dtype=dtype, name="data", as_ref=True)
308    pred = ops.convert_to_tensor(pred, name="pred")
309    if isinstance(data, ops.Tensor):
310      return gen_control_flow_ops.switch(data, pred, name=name)
311    else:
312      if not isinstance(data, composite_tensor.CompositeTensor):
313        raise TypeError("Type %s not supported" % type(data))
314      tensors = nest.flatten(data, expand_composites=True)
315      mapped = [gen_control_flow_ops.switch(tensor, pred) for tensor in tensors]
316      mapped_f, mapped_t = zip(*mapped)
317      return (nest.pack_sequence_as(data, mapped_f, expand_composites=True),
318              nest.pack_sequence_as(data, mapped_t, expand_composites=True))
319
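# Editorial sketch (not part of the original file): `switch` above returns the
# pair `(output_false, output_true)`; at runtime only the output selected by
# `pred` carries the value, the other output is "dead". Assumes graph mode;
# the helper name is hypothetical and never called by the library.
def _example_switch_usage():
  with ops.Graph().as_default():
    data = constant_op.constant([1.0, 2.0, 3.0])
    pred = constant_op.constant(True)
    output_false, output_true = switch(data, pred)
    # With pred == True, only `output_true` produces a value when executed.
    return output_false, output_true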
320
321def _SwitchRefOrTensor(data, pred, name="Switch"):
322  """Forwards `data` to an output determined by `pred`.
323
324  If `pred` is false, the `data` input is forwarded to the first output.
325  Otherwise, the data goes to the second output.
326
327  This op handles `Tensor`s and `IndexedSlices`.
328
329  Args:
330    data: The tensor to be forwarded to the appropriate output.
331    pred: A scalar that specifies which output port will receive data.
332    name: A name for this operation (optional).
333
334  Returns:
335    `(output_false, output_true)`: If `pred` is true, data will be forwarded to
336    `output_true`, otherwise it goes to `output_false`.
337
338  Raises:
339    TypeError: if data is not a Tensor or IndexedSlices
340  """
341  data = ops.convert_to_tensor_or_composite(data, name="data")
342  # NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below
343  # addresses the following scenario.
344  #
345  # Assume you execute Optimizer.apply_gradients() in a branch of a cond().
346  #
347  # 1. The update op is created inside a `with ops.colocate(var):` block
348  #
349  # 2. Some tensor `data` is captured and a switch is created in a
350  #    `with ops.colocate_with(data):` block.
351  #
352  # with ops.colocate_with(var):
353  #  with ops.colocate_with(data):
354  #    op = ...
355  #
356  # var and data may be pinned to different devices, so we want ops
357  # created within ops.colocate_with(data) to ignore the existing stack.
358  with ops.colocate_with(data, ignore_existing=True):
359    if isinstance(data, ops.Tensor):
360      if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
361        return ref_switch(data, pred, name=name)
362    return switch(data, pred, name=name)
363
364
365def merge(inputs, name=None):
366  """Returns the value of an available element of `inputs`.
367
368  This op tests each of the tensors in `inputs` in turn to determine if any of
369  them is available. If it finds an available tensor, it returns it and its
370  index in `inputs`.
371
372  It is an error if more than one tensor in `inputs` is available. If no tensor
373  in `inputs` is available, the returned tensor and index are not set.
374
375  This op handles both `Tensor`s and `IndexedSlices`. If inputs has a mix of
376  `Tensor`s and `IndexedSlices`, all inputs are converted to IndexedSlices
377  before merging.
378
379  Args:
380    inputs: The input tensors, at most one of which is available.
381    name: A name for this operation (optional).
382
383  Returns:
384    A tuple containing the chosen input tensor and its index in `inputs`.
385
386  Raises:
387    ValueError: If any of the inputs is None, or inputs are IndexedSlices and
388      some but not all have a dense_shape property.
389  """
390  if any(inp is None for inp in inputs):
391    raise ValueError("At least one of the merge inputs is None: %s" % inputs)
392  with ops.name_scope(name, "Merge", inputs) as name:
393    inputs = [
394        ops.internal_convert_to_tensor_or_composite(inp, as_ref=True)
395        for inp in inputs
396    ]
397    if all(isinstance(v, ops.Tensor) for v in inputs):
398      if all(v.dtype._is_ref_dtype for v in inputs):  # pylint: disable=protected-access
399        return gen_control_flow_ops.ref_merge(inputs, name)
400      else:
401        return gen_control_flow_ops.merge(inputs, name)
402    else:
403      # If there is a mix of tensors and indexed slices, then convert the
404      # tensors to indexed slices.
405      if all(isinstance(v, (ops.IndexedSlices, ops.Tensor)) for v in inputs):
406        inputs = math_ops._as_indexed_slices_list(inputs, optimize=False)
407
408      for v in inputs:
409        if not isinstance(v, composite_tensor.CompositeTensor):
410          raise TypeError("Type %s not supported" % type(v))
411
412      for v in inputs[1:]:
413        nest.assert_same_structure(inputs[0], v, expand_composites=True)
414
415      flat_inputs = [nest.flatten(v, expand_composites=True) for v in inputs]
416      merged_results = [
417          gen_control_flow_ops.merge(component)
418          for component in zip(*flat_inputs)
419      ]
420      flat_merged = [tensor for (tensor, _) in merged_results]
421      chosen_index = merged_results[0][1]
422      merged_inputs = nest.pack_sequence_as(
423          inputs[0], flat_merged, expand_composites=True)
424      return (merged_inputs, chosen_index)
425
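# Editorial sketch (not part of the original file): `switch` and `merge` form
# the conditional "diamond" that `cond` builds internally. The per-branch
# computations below are placeholders; assumes graph mode and the helper name
# is hypothetical and never called by the library.
def _example_switch_merge_diamond():
  with ops.Graph().as_default():
    pred = constant_op.constant(True)
    data = constant_op.constant([1.0, 2.0])
    out_false, out_true = switch(data, pred)
    # Each branch computes on the output it receives; only one is live.
    result, branch_index = merge([out_false + 1.0, out_true * 2.0])
    # Running `result` here yields [2.0, 4.0] and `branch_index` yields 1.
    return result, branch_index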
426
427# pylint: enable=protected-access
428
429
430def _convert_tensorarray_to_flow(tensor_or_tensor_array):
431  if isinstance(tensor_or_tensor_array, tensor_array_ops.TensorArray):
432    return tensor_or_tensor_array.flow
433  else:
434    return tensor_or_tensor_array
435
436
437def _convert_flows_to_tensorarrays(tensors_or_tensorarrays, tensors_or_flows):
438  if len(tensors_or_tensorarrays) != len(tensors_or_flows):
439    raise ValueError(
440        "Lengths of original Tensor list and new list do not match: %d vs. %d" %
441        (len(tensors_or_tensorarrays), len(tensors_or_flows)))
442  return [
443      tensor_array_ops.build_ta_with_new_flow(ta, t_or_flow) if isinstance(
444          ta, tensor_array_ops.TensorArray) else t_or_flow
445      for (ta, t_or_flow) in zip(tensors_or_tensorarrays, tensors_or_flows)
446  ]
447
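# Editorial sketch (not part of the original file): the two helpers above let
# control flow carry a `TensorArray` through loop/branch plumbing as its flow
# tensor and rebuild the array afterwards. Assumes graph mode; the helper name
# is hypothetical and never called by the library.
def _example_tensorarray_flow_roundtrip():
  with ops.Graph().as_default():
    ta = tensor_array_ops.TensorArray(dtypes.float32, size=2)
    flow = _convert_tensorarray_to_flow(ta)  # the array's flow Tensor
    restored, = _convert_flows_to_tensorarrays([ta], [flow])
    return restored  # a TensorArray rebuilt around the (possibly new) flow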
448
449def _ShapeLessThanOrEqual(shape1, shape2):
450  if shape2.dims is None:
451    return True
452  if shape1.ndims != shape2.ndims:
453    return False
454  for dim1, dim2 in zip(shape1.dims, shape2.dims):
455    if dim2.value is not None and dim1.value != dim2.value:
456      return False
457  return True
458
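# Editorial sketch (not part of the original file): `_ShapeLessThanOrEqual(a, b)`
# above holds when `b` is at least as general as `a` -- an unknown rank or an
# unknown dimension in `b` matches anything. The helper name is hypothetical.
def _example_shape_less_than_or_equal():
  known = tensor_shape.TensorShape([2, 3])
  relaxed = tensor_shape.TensorShape([2, None])
  unknown = tensor_shape.TensorShape(None)
  return (_ShapeLessThanOrEqual(known, relaxed),   # True: None matches 3
          _ShapeLessThanOrEqual(relaxed, known),   # False: 3 requires exactly 3
          _ShapeLessThanOrEqual(known, unknown))   # True: unknown rank matches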
459
460def _get_shape_invariant(var, shape=None):
461  """Returns shape invariant(s) for the given variable.
462
463  Args:
464    var: The tensor whose shape is described.
465    shape: The shape invariant for the tensor.  If not specified, then a default
466      shape invariant for `var` is returned.
467
468  Returns:
469    `TensorShape` or `list` of `TensorShape`: The shape invariant for `var` (if
470    it is a `Tensor`), or the shape invariants for the components that comprise
471    `var` (if it is a `CompositeTensor`).
472  """
473  if isinstance(var, composite_tensor.CompositeTensor):
474    # Get a TypeSpec for `var`.
475    if shape is None:
476      spec = var._type_spec  # pylint: disable=protected-access
477    else:
478      spec = _shape_invariant_to_type_spec(var, shape)
479
480    tensor_specs = nest.flatten(spec, expand_composites=True)
481    return [tspec.shape for tspec in tensor_specs]
482
483  elif shape is None:
484    return var.shape
485  elif isinstance(shape, tensor_spec.TensorSpec):
486    if var.dtype != shape.dtype:
487      raise TypeError("TensorSpec %r is not compatible with %r" % (shape, var))
488    return shape.shape
489  elif isinstance(shape, type_spec.TypeSpec):
490    raise TypeError("TypeSpec %r is not compatible with %r" % (shape, var))
491  else:
492    return shape
493
494
495def _shape_invariant_to_type_spec(var, shape):
496  """Converts a shape invariant to a TypeSpec.
497
498  Args:
499    var: The tensor whose shape is described by the shape invariant.
500    shape: A `TypeSpec` or `TensorShape`.  If `shape` is already a `TypeSpec`,
501      then it is simply returned as-is.
502
503  Returns:
504    A `TypeSpec` for `var`, consistent with the given shape.
505  """
506  if shape is None:
507    return type_spec.type_spec_from_value(var)
508  elif isinstance(shape, type_spec.TypeSpec):
509    if not shape.is_compatible_with(var):
510      raise TypeError("TypeSpec %r is not compatible with %r" % (shape, var))
511    return shape
512  elif not isinstance(shape, tensor_shape.TensorShape):
513    raise TypeError(
514        "Expected shape to be a TypeSpec, TensorShape or None, got %r for"
515        " value %r" % (shape, var))
516
517  if isinstance(var, ops.Tensor):
518    return tensor_spec.TensorSpec(shape, var.dtype)
519
520  elif isinstance(var, composite_tensor.CompositeTensor):
521    try:
522      return var._shape_invariant_to_type_spec(shape)  # pylint: disable=protected-access
523    except NotImplementedError:
524      raise TypeError(
525          "To describe or constrain a %s, use a %s instead of a TensorShape." %
526          (type(var).__name__, type(var._type_spec).__name__))  # pylint: disable=protected-access
527
528  else:
529    raise TypeError("Expected var to be a Tensor or CompositeTensor, got %s"
530                    % var)
531
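# Editorial sketch (not part of the original file): for a plain `Tensor`, a
# `TensorShape` invariant is converted by the function above into a
# `TensorSpec` carrying that shape and the tensor's dtype; a `TypeSpec` passes
# through after a compatibility check. The helper name is hypothetical.
def _example_shape_invariant_to_type_spec():
  var = constant_op.constant([[1.0, 2.0]])         # shape (1, 2), float32
  invariant = tensor_shape.TensorShape([None, 2])  # first dimension may vary
  # Returns TensorSpec(shape=(None, 2), dtype=float32).
  return _shape_invariant_to_type_spec(var, invariant)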
532
533def _SetShapeInvariants(input_vars, enter_vars, shapes):
534  """Set the shapes of the tensors in `enter_vars` to `shapes`.
535
536  Args:
537    input_vars: A list of tensors that are inputs to `enter_vars`.
538    enter_vars: A list of tensors whose shapes will be set.
539    shapes: A (possibly nested) list of shapes.
540
541  Raises:
542    ValueError: If any tensor in `enter_vars` has a less specific shape
543      than its corresponding shape in `shapes`.
544  """
545  if shapes is None:
546    return
547  flat_shapes = nest.flatten(shapes)
548  if not all(isinstance(s, tensor_shape.TensorShape) for s in flat_shapes):
549    raise ValueError("`shapes` must be a (possibly nested) list of shapes.")
550  # Check that the shapes of the inputs are less than the shape invariants,
551  # and set the shapes of `enter_vars` to the shape invariants.
552  for inp, var, shape in zip(input_vars, enter_vars, flat_shapes):
553    if isinstance(var, ops.Tensor):
554      if not _ShapeLessThanOrEqual(inp.get_shape(), shape):
555        raise ValueError(
556            "The shape invariant specified for %s is not compatible with "
557            "the initial shape of the loop variable. It enters the loop "
558            "with shape %s, but the specified shape invariant is %s." %
559            (inp.name, inp.get_shape(), shape))
560      var.set_shape(shape)
561    else:
562      raise TypeError("Type %s not supported" % type(var))
563
564
565def _EnforceShapeInvariant(merge_var, next_var):
566  """Check if the shapes of the loops variables are invariants.
567
568  Args:
569    merge_var: The list of tensors representing the initial values of the loop
570      variables.
571    next_var: The list of tensors representing the values of the loop variables
572      after one loop iteration.
573
574  Raises:
575    ValueError: If any tensor in `merge_var` has a more specific shape than
576      its corresponding tensor in `next_var`.
577  """
578  if isinstance(merge_var, ops.Tensor):
579    m_shape = merge_var.get_shape()
580    n_shape = next_var.get_shape()
581    if not _ShapeLessThanOrEqual(n_shape, m_shape):
582      enter = merge_var.op.inputs[0].op
583      assert util.IsLoopEnter(enter)
584      input_t = enter.inputs[0]
585      raise ValueError(
586          "Input tensor '%s' enters the loop with shape %s, but has shape %s "
587          "after one iteration. To allow the shape to vary across iterations, "
588          "use the `shape_invariants` argument of tf.while_loop to specify a "
589          "less-specific shape." % (input_t.name, input_t.shape, n_shape))
590  else:
591    raise TypeError("Type %s not supported" % type(merge_var))
592
593
594def _AddNextAndBackEdge(m, v, enforce_shape_invariant=True):
595  """Add NextIteration and back edge from v to m."""
596  if isinstance(m, ops.Tensor):
597    v = ops.convert_to_tensor(v)
598    v = _NextIteration(v)
599    if enforce_shape_invariant:
600      # Make sure the shapes of loop outputs are correct. We do this before
601      # calling _update_input, which will raise a less-helpful error message if
602      # the types don't match.
603      # TODO(skyewm): call this for other cases below (needs testing)
604      _EnforceShapeInvariant(m, v)
605    m.op._update_input(1, v)  # pylint: disable=protected-access
606  elif isinstance(m, composite_tensor.CompositeTensor):
607    # pylint: disable=protected-access
608    def update_component(m_component, v_component):
609      m_component.op._update_input(1, v_component)
610
611    if isinstance(m, ops.IndexedSlices):
612      v = math_ops._as_indexed_slices(v, optimize=False)
613    # pylint: enable=protected-access
614    v = _NextIteration(v)
615    return nest.map_structure(update_component, m, v, expand_composites=True)
616  else:
617    raise TypeError("Type %s not supported" % type(m))
618  return v
619
620
621@six.add_metaclass(abc.ABCMeta)
622class ControlFlowContext(object):
623  """The base class for control flow context.
624
625  The usage pattern is a sequence of (Enter, Exit) followed by a final
626  ExitResult.
627
628  We maintain the following state for control flow contexts during graph
629  construction:
630   1. graph has _control_flow_context: the current context used to
631      construct new nodes. Changed by ctxt.Enter() and ctxt.Exit()
632   2. op has _control_flow_context: the context to which the op belongs.
633      Set at the time the op is created. Immutable.
634   3. A ControlFlowContext has _outer_context: the context in which this
635      context is created. Set at the time a context is created. Immutable.
636   4. A ControlFlowContext has _context_stack.
637      Pushed and popped by ctxt.Enter() and ctxt.Exit()
638  """
639
640  def __init__(self, values_def=None, import_scope=None):
641    self._nested_contexts = []
642    self._outer_context = ops.get_default_graph()._get_control_flow_context()
643    if self._outer_context:
644      self._outer_context._nested_contexts.append(self)  # pylint: disable=protected-access
645    self._context_stack = []
646    if values_def:
647      self._init_values_from_proto(values_def, import_scope=import_scope)
648    else:
649      # The names of tensors that have been already seen in this context.
650      self._values = set()
651      # The keys are the names of tensors referenced by but external to this
652      # context. Each value is the Tensor that should be used by this context to
653      # access the key value (e.g. a switch output guarding a cond input value).
654      self._external_values = {}
655
656  def _init_values_from_proto(self, values_def, import_scope=None):
657    """Initializes values and external_values from `ValuesDef` protocol buffer.
658
659    Args:
660      values_def: `ValuesDef` protocol buffer.
661      import_scope: Optional `string`. Name scope to add.
662    """
663    assert isinstance(values_def, control_flow_pb2.ValuesDef)
664    self._values = set(
665        ops.prepend_name_scope(value, import_scope)
666        for value in values_def.values)
667    g = ops.get_default_graph()
668    self._external_values = {}
669    for k, v in values_def.external_values.items():
670      k = ops.prepend_name_scope(k, import_scope)
671      self._external_values[k] = g.as_graph_element(
672          ops.prepend_name_scope(v, import_scope))
673    op_names = set([
674        op.split(":")[0]
675        for op in self._values - set(self._external_values.keys())
676    ])
677    for op in op_names:
678      # pylint: disable=protected-access
679      g.as_graph_element(op)._set_control_flow_context(self)
680      # pylint: enable=protected-access
681
682  @property
683  def name(self):
684    return self._name
685
686  @property
687  def outer_context(self):
688    """Return the context containing this context."""
689    return self._outer_context
690
691  @property
692  def grad_state(self):
693    raise NotImplementedError("Abstract method")
694
695  @property
696  def back_prop(self):
697    raise NotImplementedError("Abstract method")
698
699  @abc.abstractmethod
700  def to_control_flow_context_def(self, context_def, export_scope=None):
701    """Serializes this into `context_def`.
702
703    Args:
704      context_def: a `ControlFlowContextDef` protocol buffer.
705      export_scope: Optional `string`. Name scope to remove.
706    """
707    raise NotImplementedError("Abstract method")
708
709  def _to_values_def(self, export_scope=None):
710    """Converts the values to a `ValuesDef` protocol buffer.
711
712    Args:
713      export_scope: Optional `string`. Name scope to remove.
714
715    Returns:
716      A `ValuesDef` protocol buffer.
717    """
718    values_def = control_flow_pb2.ValuesDef()
719    values_def.values.extend(
720        [ops.strip_name_scope(v, export_scope) for v in sorted(self._values)])
721    for k, v in self._external_values.items():
722      k = ops.strip_name_scope(k, export_scope)
723      values_def.external_values[k] = ops.strip_name_scope(v.name, export_scope)
724    return values_def
725
726  def AddName(self, name):
727    self._values.add(name)
728
729  # pylint: disable=protected-access
730  def Enter(self):
731    """Enter this control flow context."""
732    graph = ops.get_default_graph()
733    self._context_stack.append(graph._get_control_flow_context())
734    graph._set_control_flow_context(self)
735
736  def Exit(self):
737    """Exit this control flow context."""
738    graph = ops.get_default_graph()
739    last_context = self._context_stack.pop()
740    graph._set_control_flow_context(last_context)
741
742  def EnterGradientColocation(self, op, gradient_uid):
743    """Start building a gradient colocated with an op."""
744    if self._outer_context:
745      self._outer_context.EnterGradientColocation(op, gradient_uid)
746
747  def ExitGradientColocation(self, op, gradient_uid):
748    """Start building a gradient colocated with an op."""
749    if self._outer_context:
750      self._outer_context.ExitGradientColocation(op, gradient_uid)
751
752  def ExitResult(self, result):
753    """Make a list of tensors available in the outer context."""
754    if self._outer_context:
755      def fn(x):
756        self._outer_context.AddName(x.name)
757        return x
758      nest.map_structure(fn, result, expand_composites=True)
759
760  def GetWhileContext(self):
761    """Return the while context containing this context."""
762    if self._outer_context:
763      return self._outer_context.GetWhileContext()
764    return None
765
766  def _RemoveExternalControlEdges(self, op):
767    """Remove any external control dependency on this op."""
768    while_ctxt = self.GetWhileContext()
769    # A control input of `op` is internal if it is in the same while
770    # loop context as the enclosing while loop context of self.
771    if while_ctxt is None:
772      internal_control_inputs = op.control_inputs
773    else:
774      internal_control_inputs = []
775      for x in op.control_inputs:
776        ctxt = util.GetOutputContext(x)
777        if ctxt is not None and ctxt.GetWhileContext() == while_ctxt:
778          internal_control_inputs.append(x)
779    external_control_inputs = []
780    if len(internal_control_inputs) != len(op.control_inputs):
781      external_control_inputs = list(
782          set(op.control_inputs) - set(internal_control_inputs))
783      op._remove_all_control_inputs()
784      op._add_control_inputs(internal_control_inputs)
785    return internal_control_inputs, external_control_inputs
786
787  # pylint: enable=protected-access
788
789  def AddInnerOp(self, op):
790    """Notifies a scope about an operator added to an inner scope."""
791    if self._outer_context:
792      self._outer_context.AddInnerOp(op)
793
794  def GetControlPivot(self):
795    """Returns the pivot node for this context, or None."""
796    return None
797
798  def IsWhileContext(self):
799    return False
800
801  def IsCondContext(self):
802    return False
803
804  def IsXLAContext(self):
805    return False
806
807  def __str__(self):
808    return self.name
809
810
811class CondContext(ControlFlowContext):
812  """The context for the conditional construct."""
813
814  def __init__(self,
815               pred=None,
816               pivot=None,
817               branch=None,
818               name="cond_text",
819               context_def=None,
820               import_scope=None):
821    """Creates a `CondContext`.
822
823    Args:
824      pred: The `boolean` tensor for the conditional predicate.
825      pivot: The predicate tensor in this branch.
826      branch: 0 or 1 representing this branch.
827      name: Name of the `CondContext` python object.
828      context_def: Optional `ContextDef` protocol buffer to initialize the
829        `CondContext` object from.
830      import_scope: Optional `string`. Name scope to add. Only used when
831        initializing from a protocol buffer.
832    """
833    self._name = ops.get_default_graph().unique_name(name)
834
835    if context_def:
836      self._init_from_proto(context_def, import_scope=import_scope)
837    else:
838      # Initializes the default fields.
839      ControlFlowContext.__init__(self)
840      self._pred = pred  # The boolean tensor for the cond predicate
841      self._pivot = pivot  # The predicate tensor in this branch
842      self._branch = branch  # 0 or 1 representing this branch
843
844      # Values considered to have been already seen in this context. pred is not
845      # included in this context.
846      self._values.add(pred.name)
847      self._external_values[pred.name] = pred
848      self._values.add(pivot.name)
849      pivot.op._set_control_flow_context(self)  # pylint: disable=protected-access
850
851  def _init_from_proto(self, context_def, import_scope=None):
852    """Creates a new `CondContext` from protocol buffer.
853
854    Args:
855      context_def: `CondContextDef` protocol buffer.
856      import_scope: Optional `string`. Name scope to add.
857    """
858    assert isinstance(context_def, control_flow_pb2.CondContextDef)
859    # Create from context_def.
860    g = ops.get_default_graph()
861    self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
862    self._pred = g.as_graph_element(
863        ops.prepend_name_scope(context_def.pred_name, import_scope))
864    self._pivot = g.as_graph_element(
865        ops.prepend_name_scope(context_def.pivot_name, import_scope))
866    self._branch = context_def.branch
867    super(CondContext, self).__init__(
868        values_def=context_def.values_def, import_scope=import_scope)
869
870  @property
871  def pred(self):
872    return self._pred
873
874  @property
875  def pivot(self):
876    return self._pivot
877
878  @property
879  def branch(self):
880    return self._branch
881
882  @property
883  def grad_state(self):
884    if self.GetWhileContext():
885      return self.GetWhileContext().grad_state
886    return None
887
888  @property
889  def back_prop(self):
890    if self.GetWhileContext():
891      return self.GetWhileContext().back_prop
892    return False
893
894  def GetControlPivot(self):
895    return self._pivot
896
897  def to_proto(self, export_scope=None):
898    """Converts a `CondContext` to a `CondContextDef` protocol buffer.
899
900    Args:
901      export_scope: Optional `string`. Name scope to remove.
902
903    Returns:
904      A `CondContextDef` protocol buffer.
905    """
906    if (export_scope is None or self.name.startswith(export_scope)):
907      context_def = control_flow_pb2.CondContextDef()
908      context_def.context_name = ops.strip_name_scope(self.name, export_scope)
909      context_def.pred_name = ops.strip_name_scope(self._pred.name,
910                                                   export_scope)
911      context_def.pivot_name = ops.strip_name_scope(self._pivot.name,
912                                                    export_scope)
913      context_def.branch = self._branch
914      context_def.values_def.MergeFrom(
915          super(CondContext, self)._to_values_def(export_scope))
916      for nested in self._nested_contexts:
917        nested_def = context_def.nested_contexts.add()
918        nested.to_control_flow_context_def(nested_def)
919
920      return context_def
921    else:
922      return None
923
924  @staticmethod
925  def from_proto(context_def, import_scope=None):
926    """Returns a `CondContext` object created from `context_def`."""
927    ret = CondContext(context_def=context_def, import_scope=import_scope)
928
929    ret.Enter()
930    for nested_def in context_def.nested_contexts:
931      from_control_flow_context_def(nested_def, import_scope=import_scope)
932    ret.Exit()
933    return ret
934
935  def to_control_flow_context_def(self, context_def, export_scope=None):
936    context_def.cond_ctxt.CopyFrom(self.to_proto(export_scope=export_scope))
937
938  def AddValue(self, val):
939    """Add `val` to the current context and its outer context recursively."""
940    if val.name in self._values:
941      # Use the real value if it comes from outer context. This is needed in
942      # particular for nested conds.
943      result = self._external_values.get(val.name)
944      result = val if result is None else result
945    else:
946      result = val
947      self._values.add(val.name)
948      if self._outer_context:
949        result = self._outer_context.AddValue(val)
950        self._values.add(result.name)
951        self._external_values[result.name] = result
952      with ops.control_dependencies(None):
953        result = _SwitchRefOrTensor(result, self._pred)[self._branch]
954        if self._outer_context:
955          self._outer_context.AddInnerOp(result.op)
956
957      result.op.graph.prevent_fetching(result.op)
958      # pylint: disable=protected-access
959      result.op._set_control_flow_context(self)
960      # pylint: enable=protected-access
961
962      # Mark Switch output as seen by this context and any outer contexts,
963      # just like what we do for normal op outputs in _AddOpInternal() below.
964      ctxt = self
965      while ctxt is not None:
966        # pylint: disable=protected-access
967        ctxt._values.add(result.name)
968        ctxt = ctxt._outer_context
969        # pylint: enable=protected-access
970
971      self._external_values[val.name] = result
972    return result
973
974  def AddOp(self, op):
975    self._AddOpInternal(op)
976
977  def _AddOpInternal(self, op):
978    """Add `op` to the current context."""
979    if not op.inputs:
980      # If we're in a while loop, remove any control inputs from outside the
981      # loop.
982      self._RemoveExternalControlEdges(op)
983
984      if not any(
985          util.OpInContext(input_op, self) for input_op in op.control_inputs):
986        # pylint: disable=protected-access
987        op._add_control_input(self._pivot.op)
988        # pylint: enable=protected-access
989    else:
990      # Make each input to 'op' available in this CondContext. If an input is
991      # already part of this context there's nothing to do, but if it's
992      # external, AddValue() will handle adding the appropriate Switch node and
993      # other bookkeeping.
994      for index in range(len(op.inputs)):
995        x = op.inputs[index]
996        if op.type == "Merge" and x.op.type == "NextIteration":
997          # Edge case: if we're importing a while loop inside this CondContext,
998          # AddValue() will not correctly handle the NextIteration inputs to
999          # Merge node. The problem is that the NextIteration should also be
1000          # part of this context, but if we're importing it won't have been
1001          # processed and added to the context yet, so AddValue() will try to
1002          # add a Switch which results in an invalid graph. Instead, we use the
1003          # NextIteration input as-is here, and it will eventually be added to
1004          # the context via AddOp().
1005          real_x = x
1006        else:
1007          real_x = self.AddValue(x)
1008        if real_x != x:
1009          # pylint: disable=protected-access
1010          op._update_input(index, real_x)
1011          # pylint: enable=protected-access
1012      # Remove any external control dependency on this op.
1013      self._RemoveExternalControlEdges(op)
1014      # pylint: disable=protected-access
1015      if op.graph._is_function(op.type) or op.type == "SymbolicGradient":
1016        op._add_control_input(self._pivot.op)
1017      # pylint: enable=protected-access
1018
1019    # Mark op's outputs as seen by this context and any outer contexts.
1020    output_names = [x.name for x in op.outputs]
1021    ctxt = self
1022    while ctxt is not None:
1023      # pylint: disable=protected-access
1024      ctxt._values.update(output_names)
1025      ctxt = ctxt._outer_context
1026      # pylint: enable=protected-access
1027
1028    if self._outer_context or not util.IsLoopExit(op):
1029      op.graph.prevent_fetching(op)
1030
1031    if self._outer_context:
1032      self._outer_context.AddInnerOp(op)
1033
1034  def _ProcessOutputTensor(self, val):
1035    """Process an output tensor of a conditional branch."""
1036    real_val = val
1037    if val.name not in self._values:
1038      # Handle the special case of lambda: x
1039      self._values.add(val.name)
1040      if self._outer_context:
1041        real_val = self._outer_context.AddValue(val)
1042        self._values.add(real_val.name)
1043        self._external_values[real_val.name] = real_val
1044      real_val = _SwitchRefOrTensor(real_val, self._pred)[self._branch]
1045      self._external_values[val.name] = real_val
1046    else:
1047      external_val = self._external_values.get(val.name)
1048      if external_val is not None:
1049        real_val = external_val
1050    return real_val
1051
1052  def _BuildCondTensor(self, v):
1053    if isinstance(v, ops.Operation):
1054      # Use pivot as the proxy for this op.
1055      return with_dependencies([v], self._pivot)
1056    else:
1057      v = nest.map_structure(
1058          _convert_tensorarray_to_flow, v, expand_composites=True)
1059      return self._ProcessOutputTensor(ops.convert_to_tensor(v))
1060
1061  def BuildCondBranch(self, fn):
1062    """Add the subgraph defined by fn() to the graph."""
1063    pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
1064    original_result = fn()
1065    post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
1066    if len(post_summaries) > len(pre_summaries):
1067      new_summaries = post_summaries[len(pre_summaries):]
1068      summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
1069      summary_ref[:] = pre_summaries
1070      with ops.control_dependencies(new_summaries):
1071        if original_result is None:
1072          return no_op(), None
1073        elif not isinstance(original_result, ops.Operation):
1074          original_result = nest.map_structure(
1075              array_ops.identity, original_result, expand_composites=True)
1076    if original_result is None:
1077      return None, None
1078
1079    result = nest.map_structure(
1080        self._BuildCondTensor, original_result, expand_composites=True)
1081    if not isinstance(result, (list, _basetuple)):
1082      result = [result]
1083    return original_result, result
1084
1085  def IsCondContext(self):
1086    return True
1087
1088
1089def _UnpackIfSingleton(res):
1090  if isinstance(res, (list, _basetuple)) and len(res) == 1:
1091    return res[0]
1092  else:
1093    return res
1094
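# Editorial sketch (not part of the original file): `_UnpackIfSingleton` above
# implements the non-strict behavior of `cond` -- a single-element list or
# tuple returned by a branch is unwrapped to its sole element, and anything
# else is returned unchanged. The helper name is hypothetical.
def _example_unpack_if_singleton():
  return (_UnpackIfSingleton([7]),     # -> 7
          _UnpackIfSingleton([7, 8]),  # -> [7, 8]
          _UnpackIfSingleton("abc"))   # -> "abc" (not a list or tuple)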
1095
1096# pylint: disable=redefined-outer-name
1097# pylint: disable=g-doc-args
1098@tf_export(v1=["cond"])
1099@deprecation.deprecated_args(
1100    None, "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.",
1101    "fn1", "fn2")
1102def cond(pred,
1103         true_fn=None,
1104         false_fn=None,
1105         strict=False,
1106         name=None,
1107         fn1=None,
1108         fn2=None):
1109  """Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
1110
1111  `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
1112  `false_fn` must have the same non-zero number and type of outputs.
1113
1114  **WARNING**: Any Tensors or Operations created outside of `true_fn` and
1115  `false_fn` will be executed regardless of which branch is selected at runtime.
1116
1117  Although this behavior is consistent with the dataflow model of TensorFlow,
1118  it has frequently surprised users who expected lazier semantics.
1119  Consider the following simple program:
1120
1121  ```python
1122  z = tf.multiply(a, b)
1123  result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
1124  ```
1125
1126  If `x < y`, the `tf.add` operation will be executed and the `tf.square`
1127  operation will not be executed. Since `z` is needed for at least one
1128  branch of the `cond`, the `tf.multiply` operation is always executed,
1129  unconditionally.
1130
1131  Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
1132  call to `cond`, and not at all during `Session.run()`). `cond`
1133  stitches together the graph fragments created during the `true_fn` and
1134  `false_fn` calls with some additional graph nodes to ensure that the right
1135  branch gets executed depending on the value of `pred`.
1136
1137  `tf.cond` supports nested structures as implemented in
1138  `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
1139  same (possibly nested) value structure of lists, tuples, and/or named tuples.
1140  Singleton lists and tuples form the only exceptions to this: when returned by
1141  `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
1142  This behavior is disabled by passing `strict=True`.
1143
1144  Args:
1145    pred: A scalar determining whether to return the result of `true_fn` or
1146      `false_fn`.
1147    true_fn: The callable to be performed if pred is true.
1148    false_fn: The callable to be performed if pred is false.
1149    strict: A boolean that enables/disables 'strict' mode; see above.
1150    name: Optional name prefix for the returned tensors.
1151
1152  Returns:
1153    Tensors returned by the call to either `true_fn` or `false_fn`. If the
1154    callables return a singleton list, the element is extracted from the list.
1155
1156  Raises:
1157    TypeError: if `true_fn` or `false_fn` is not callable.
1158    ValueError: if `true_fn` and `false_fn` do not return the same number of
1159      tensors, or return tensors of different types.
1160
1161  Example:
1162
1163  ```python
1164  x = tf.constant(2)
1165  y = tf.constant(5)
1166  def f1(): return tf.multiply(x, 17)
1167  def f2(): return tf.add(y, 23)
1168  r = tf.cond(tf.less(x, y), f1, f2)
1169  # r is set to f1().
1170  # Operations in f2 (e.g., tf.add) are not executed.
1171  ```
1172
1173  """
1174  # Always enable control flow v2 if building a function, regardless of toggle.
1175  if (util.EnableControlFlowV2(ops.get_default_graph()) and
1176      not context.executing_eagerly()):
1177    return cond_v2.cond_v2(pred, true_fn, false_fn, name)
1178
1179  # We needed to make true_fn/false_fn keyword arguments for
1180  # backwards-compatibility. This check exists so that we can convert back to
1181  # having them be positional arguments.
1182  # TODO(josh11b): Make `true_fn` and `false_fn` positional arguments after
1183  # `fn1` and `fn2` are deleted.
1184  if fn1 is not None:
1185    if true_fn is not None:
1186      raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.")
1187    true_fn = fn1
1188  elif true_fn is None:
1189    raise TypeError("cond(): true_fn argument required")
1190  if fn2 is not None:
1191    if false_fn is not None:
1192      raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.")
1193    false_fn = fn2
1194  elif false_fn is None:
1195    raise TypeError("cond(): false_fn argument required")
1196
1197  if not callable(true_fn):
1198    raise TypeError("true_fn must be callable.")
1199  if not callable(false_fn):
1200    raise TypeError("false_fn must be callable.")
1201
1202  with ops.name_scope(name, "cond", [pred]):
1203    if context.executing_eagerly():
1204      if pred:
1205        result = true_fn()
1206      else:
1207        result = false_fn()
1208      if not strict:
1209        result = _UnpackIfSingleton(result)
1210      return result
1211
1212    # Add the Switch to the graph.
1213    if isinstance(pred, bool):
1214      raise TypeError("pred must not be a Python bool")
1215    p_2, p_1 = switch(pred, pred)
1216    pivot_1 = array_ops.identity(p_1, name="switch_t")
1217    pivot_2 = array_ops.identity(p_2, name="switch_f")
1218    pred = array_ops.identity(pred, name="pred_id")
1219    # Disable the fetching of tensors that are only on one branch of cond.
1220    for tensor in [p_1, p_2, pivot_1, pivot_2, pred]:
1221      tensor.op.graph.prevent_fetching(tensor.op)
1222
1223    # Build the graph for the true branch in a new context.
1224    context_t = CondContext(pred, pivot_1, branch=1)
1225    try:
1226      context_t.Enter()
1227      orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
1228      if orig_res_t is None:
1229        raise ValueError("true_fn must have a return value.")
1230      context_t.ExitResult(res_t)
1231    finally:
1232      context_t.Exit()
1233
1234    # Build the graph for the false branch in a new context.
1235    context_f = CondContext(pred, pivot_2, branch=0)
1236    try:
1237      context_f.Enter()
1238      orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
1239      if orig_res_f is None:
1240        raise ValueError("false_fn must have a return value.")
1241      context_f.ExitResult(res_f)
1242    finally:
1243      context_f.Exit()
1244
1245    if not strict:
1246      orig_res_t = _UnpackIfSingleton(orig_res_t)
1247      orig_res_f = _UnpackIfSingleton(orig_res_f)
1248
1249    # Check that the return values of the two branches have the same structure.
1250    try:
1251      nest.assert_same_structure(orig_res_t, orig_res_f, expand_composites=True)
1252    except (TypeError, ValueError):
1253      nest.map_structure(_cast_indexed_slice_indices, orig_res_t, orig_res_f)
1254      nest.map_structure(_cast_indexed_slice_indices, res_t, res_f)
1255      try:
1256        nest.assert_same_structure(orig_res_t, orig_res_f,
1257                                   expand_composites=True)
1258      except TypeError as e:
1259        raise TypeError(
1260            "Incompatible return types of true_fn and false_fn: {}".format(e))
1261      except ValueError as e:
1262        raise ValueError(
1263            "Incompatible return values of true_fn and false_fn: {}".format(e))
1264
1265    # Add the final merge to the graph.
1266    if not res_t:
1267      raise ValueError("true_fn and false_fn must return at least one result.")
1268
1269    res_t_flat = nest.flatten(res_t, expand_composites=True)
1270    res_f_flat = nest.flatten(res_f, expand_composites=True)
1271
1272    for (x, y) in zip(res_t_flat, res_f_flat):
1273      assert isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor)
1274      if x.dtype.base_dtype != y.dtype.base_dtype:
1275        raise ValueError(
1276            "Outputs of true_fn and false_fn must have the same type: "
1277            "%s, %s" % (x.dtype.name, y.dtype.name))
1278
1279    merges = [merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)]
1280    merges = _convert_flows_to_tensorarrays(
1281        nest.flatten(orig_res_t, expand_composites=True), merges)
1282
1283    # Only add non-nested conds to the collection. Any nested control flow will
1284    # be encapsulated in the root context.
1285    assert context_t.outer_context == context_f.outer_context
1286    if context_t.outer_context is None:
1287      ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t)
1288      ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f)
1289
1290    merges = nest.pack_sequence_as(
1291        structure=orig_res_t, flat_sequence=merges, expand_composites=True)
1292
1293    # Singleton lists and tuples are automatically unpacked if strict == False.
1294    if not strict:
1295      merges = _UnpackIfSingleton(merges)
1296    return merges
1297
1298
1299def _cast_indexed_slice_indices(a, b):
1300  """Cast IndexedSlice.indices from int32 to int64 where necessary.
1301
1302  If `a` and `b` are both IndexedSlices, and their indices have different
1303  dtypes, then cast both their dtypes to `int64` (modifies `a` and `b`
1304  in-place).  Otherwise, does nothing.
1305
1306  Args:
1307    a: A value, which may be an IndexedSlices.
1308    b: A value, which may be an IndexedSlices.
1309  """
1310  if (isinstance(a, ops.IndexedSlices) and isinstance(b, ops.IndexedSlices)
1311      and a.indices.dtype != b.indices.dtype):
1312    # pylint: disable=protected-access
1313    a._indices = math_ops.cast(a.indices, dtypes.int64)
1314    b._indices = math_ops.cast(b.indices, dtypes.int64)
1315
1316
1317# pylint: enable=g-doc-args
1318# pylint: enable=redefined-outer-name
1319
1320
1321@tf_export("cond", v1=[])
1322def cond_for_tf_v2(pred, true_fn=None, false_fn=None, name=None):
1323  """Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
1324
1325  `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
1326  `false_fn` must have the same non-zero number and type of outputs.
1327
1328  **WARNING**: Any Tensors or Operations created outside of `true_fn` and
1329  `false_fn` will be executed regardless of which branch is selected at runtime.
1330
1331  Although this behavior is consistent with the dataflow model of TensorFlow,
1332  it has frequently surprised users who expected lazier semantics.
1333  Consider the following simple program:
1334
1335  ```python
1336  z = tf.multiply(a, b)
1337  result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
1338  ```
1339
1340  If `x < y`, the `tf.add` operation will be executed and the `tf.square`
1341  operation will not be executed. Since `z` is needed for at least one
1342  branch of the `cond`, the `tf.multiply` operation is always executed,
1343  unconditionally.
1344
1345  Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
1346  call to `cond`, and not at all during `Session.run()`). `cond`
1347  stitches together the graph fragments created during the `true_fn` and
1348  `false_fn` calls with some additional graph nodes to ensure that the right
1349  branch gets executed depending on the value of `pred`.
1350
1351  `tf.cond` supports nested structures as implemented in
1352  `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
1353  same (possibly nested) value structure of lists, tuples, and/or named tuples.
1354  Singleton lists and tuples form the only exceptions to this: when returned by
1355  `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
1356
1357  Note: It is illegal to "directly" use tensors created inside a cond branch
1358  outside it, e.g. by storing a reference to a branch tensor in the python
1359  state. If you need to use a tensor created in a branch function you should
1360  return it as an output of the branch function and use the output from
1361  `tf.cond` instead.
1362
1363  Args:
1364    pred: A scalar determining whether to return the result of `true_fn` or
1365      `false_fn`.
1366    true_fn: The callable to be performed if pred is true.
1367    false_fn: The callable to be performed if pred is false.
1368    name: Optional name prefix for the returned tensors.
1369
1370  Returns:
1371    Tensors returned by the call to either `true_fn` or `false_fn`. If the
1372    callables return a singleton list, the element is extracted from the list.
1373
1374  Raises:
1375    TypeError: if `true_fn` or `false_fn` is not callable.
1376    ValueError: if `true_fn` and `false_fn` do not return the same number of
1377      tensors, or return tensors of different types.
1378
1379  Example:
1380
1381  ```python
1382  x = tf.constant(2)
1383  y = tf.constant(5)
1384  def f1(): return tf.multiply(x, 17)
1385  def f2(): return tf.add(y, 23)
1386  r = tf.cond(tf.less(x, y), f1, f2)
1387  # r is set to f1().
1388  # Operations in f2 (e.g., tf.add) are not executed.
1389  ```
1390
1391  """
1392  return cond(pred, true_fn=true_fn, false_fn=false_fn, strict=True, name=name)
1393
1394
1395def _resource_safe_shape(t):
1396  """Returns the shape of t or the variable it points to."""
1397  if t.dtype == dtypes.resource:
1398    while t.op.inputs:
1399      t = t.op.inputs[0]
1400    return tensor_shape.TensorShape(t.op.get_attr("shape"))
1401  return array_ops.shape_internal(t, optimize=False)
1402
1403
1404# TODO(yuanbyu): Consider having a unified notion of context for
1405# not only conditionals and loops but also control dependency and
1406# subgraphs.
1407class WhileContext(ControlFlowContext):
1408  """The context for the loop construct."""
1409
1410  def __init__(self,
1411               maximum_iterations=None,
1412               parallel_iterations=10,
1413               back_prop=True,
1414               swap_memory=False,
1415               name="while_context",
1416               grad_state=None,
1417               context_def=None,
1418               import_scope=None):
1419    """"Creates a `WhileContext`.
1420
1421    Args:
1422      maximum_iterations: Optional upper bound on number of loop iterations.
1423      parallel_iterations: The number of iterations allowed to run in parallel.
1424      back_prop: Whether backprop is enabled for this while loop.
1425      swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
1426      name: Optional name prefix for the returned tensors.
1427      grad_state: The gradient loop state.
1428      context_def: Optional `WhileContextDef` protocol buffer to initialize the
        `WhileContext` Python object from.
1430      import_scope: Optional `string`. Name scope to add. Only used when
        initializing from a protocol buffer.
1432    """
1433    if context_def:
1434      self._init_from_proto(context_def, import_scope=import_scope)
1435    else:
1436      ControlFlowContext.__init__(self)
1437      self._init_from_args(maximum_iterations, parallel_iterations, back_prop,
1438                           swap_memory, name)
1439    # The gradient loop state.
1440    self._grad_state = grad_state
1441
1442  def _init_from_args(self, maximum_iterations, parallel_iterations, back_prop,
1443                      swap_memory, name):
1444    """Creates a new `WhileContext` from arguments.
1445
1446    Args:
1447      maximum_iterations: Optional upper bound on number of loop iterations.
1448      parallel_iterations: The number of iterations allowed to run in parallel.
1449      back_prop: Whether backprop is enabled for this while loop.
1450      swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
1451      name: Optional name prefix for the returned tensors.
1452
1453    Raises:
1454      ValueError: If `parallel_iterations` has invalid value.
1455    """
1456    if not isinstance(parallel_iterations, int) or (parallel_iterations <= 0):
1457      raise ValueError("`parallel_iterations` must be a positive integer: "
1458                       "%s" % parallel_iterations)
1459    self._name = ops.get_default_graph().unique_name(name)
1460    self._maximum_iterations = maximum_iterations
1461    self._parallel_iterations = parallel_iterations
1462    self._back_prop = back_prop
1463    self._swap_memory = swap_memory
1464    # We use this node to control constants created by the pred lambda.
1465    self._pivot_for_pred = None
1466    # We use this node to control constants created by the body lambda.
1467    self._pivot_for_body = None
1468    # The boolean tensor for loop termination condition. Used in code
1469    # generation for gradient computation
1470    self._pivot = None
1471    # The list of exit tensors for loop variables.
1472    self._loop_exits = []
1473    # The list of enter tensors for loop variables.
1474    self._loop_enters = []
1475    self._graph = ops.get_default_graph()
1476
1477  def _init_from_proto(self, context_def, import_scope=None):
1478    """Creates a new `WhileContext` from protocol buffer.
1479
1480    Args:
1481      context_def: `WhileContextDef` protocol buffer.
1482      import_scope: Optional `string`. Name scope to add.
1483    """
1484    assert isinstance(context_def, control_flow_pb2.WhileContextDef)
1485    # Create from context_def.
1486    g = ops.get_default_graph()
1487    self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
1488    if context_def.maximum_iterations_name:
1489      self._maximum_iterations = g.as_graph_element(
1490          ops.prepend_name_scope(context_def.maximum_iterations_name,
1491                                 import_scope))
1492    else:
1493      self._maximum_iterations = None
1494    self._parallel_iterations = context_def.parallel_iterations
1495    self._back_prop = context_def.back_prop
1496    self._swap_memory = context_def.swap_memory
1497    self._pivot_for_pred = g.as_graph_element(
1498        ops.prepend_name_scope(context_def.pivot_for_pred_name, import_scope))
1499    # We use this node to control constants created by the body lambda.
1500    self._pivot_for_body = g.as_graph_element(
1501        ops.prepend_name_scope(context_def.pivot_for_body_name, import_scope))
1502    # The boolean tensor for loop termination condition. Used in code
1503    # generation for gradient computation.
1504    self._pivot = g.as_graph_element(
1505        ops.prepend_name_scope(context_def.pivot_name, import_scope))
1506    # The list of exit tensors for loop variables.
1507    self._loop_exits = [
1508        g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope))
1509        for exit_name in context_def.loop_exit_names
1510    ]
1511    # The list of enter tensors for loop variables.
1512    self._loop_enters = [
1513        g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope))
1514        for enter_name in context_def.loop_enter_names
1515    ]
1516    super(WhileContext, self).__init__(
1517        values_def=context_def.values_def, import_scope=import_scope)
1518
1519    # import_scope causes self.name to be different from the original serialized
1520    # context's name. Rewrite "frame_name" attrs with the new name.
1521    if import_scope:
1522      for tensor_name in self._values:
1523        op = g.as_graph_element(tensor_name).op
1524        if util.IsLoopEnter(op):
1525          # pylint: disable=protected-access
1526          op._set_attr("frame_name",
1527                       attr_value_pb2.AttrValue(s=compat.as_bytes(self.name)))
1528          # pylint: enable=protected-access
1529    self._graph = ops.get_default_graph()
1530
1531  @property
1532  def maximum_iterations(self):
1533    """The maximum number of iterations that will be executed."""
1534    return self._maximum_iterations
1535
1536  @property
1537  def parallel_iterations(self):
1538    """The number of iterations allowed to run in parallel."""
1539    return self._parallel_iterations
1540
1541  @property
1542  def back_prop(self):
1543    """True iff backprop is enabled for this while loop."""
1544    return self._back_prop
1545
1546  @property
1547  def swap_memory(self):
1548    """True iff GPU-CPU memory swap is enabled for this while loop."""
1549    return self._swap_memory
1550
1551  @property
1552  def pivot(self):
1553    """The boolean tensor representing the loop termination condition."""
1554    return self._pivot
1555
1556  @property
1557  def loop_enters(self):
1558    """The list of enter tensors for loop variables."""
1559    return self._loop_enters
1560
1561  @property
1562  def loop_exits(self):
1563    """The list of exit tensors for loop variables."""
1564    return self._loop_exits
1565
1566  @property
1567  def grad_state(self):
1568    """The gradient loop state."""
1569    return self._grad_state
1570
1571  def to_proto(self, export_scope=None):
1572    """Converts a `WhileContext` to a `WhileContextDef` protocol buffer.
1573
1574    Args:
1575      export_scope: Optional `string`. Name scope to remove.
1576
1577    Returns:
1578      A `WhileContextDef` protocol buffer.
1579    """
1580    if (export_scope is None or self.name.startswith(export_scope)):
1581      context_def = control_flow_pb2.WhileContextDef()
1582      context_def.context_name = ops.strip_name_scope(self.name, export_scope)
1583      context_def.parallel_iterations = self._parallel_iterations
1584      if self._maximum_iterations is not None:
1585        context_def.maximum_iterations_name = ops.strip_name_scope(
1586            self._maximum_iterations.name, export_scope)
1587      context_def.back_prop = self._back_prop
1588      context_def.swap_memory = self._swap_memory
1589      context_def.pivot_for_pred_name = ops.strip_name_scope(
1590          self._pivot_for_pred.name, export_scope)
1591      context_def.pivot_for_body_name = ops.strip_name_scope(
1592          self._pivot_for_body.name, export_scope)
1593      context_def.pivot_name = ops.strip_name_scope(self._pivot.name,
1594                                                    export_scope)
1595      context_def.loop_exit_names.extend([
1596          ops.strip_name_scope(l.name, export_scope) for l in self._loop_exits
1597      ])
1598      context_def.loop_enter_names.extend([
1599          ops.strip_name_scope(l.name, export_scope) for l in self._loop_enters
1600      ])
1601      context_def.values_def.MergeFrom(
1602          super(WhileContext, self)._to_values_def(export_scope=export_scope))
1603      for nested in self._nested_contexts:
1604        nested_def = context_def.nested_contexts.add()
1605        nested.to_control_flow_context_def(nested_def)
1606
1607      return context_def
1608    else:
1609      return None
1610
1611  def to_control_flow_context_def(self, context_def, export_scope=None):
1612    context_def.while_ctxt.CopyFrom(self.to_proto(export_scope=export_scope))
1613
1614  @staticmethod
1615  def from_proto(context_def, import_scope=None):
1616    """Returns a `WhileContext` object created from `context_def`.
1617
1618    Args:
1619      context_def: A `WhileContextDef` protocol buffer.
1620      import_scope: Optional `string`. Name scope to add.
1621
1622    Returns:
1623      A `WhileContext` Python object.
1624    """
1625    ret = WhileContext(context_def=context_def, import_scope=import_scope)
1626    ret.Enter()
1627    for nested_def in context_def.nested_contexts:
1628      from_control_flow_context_def(nested_def, import_scope=import_scope)
1629    ret.Exit()
1630    return ret
1631
1632  def GetWhileContext(self):
1633    return self
1634
1635  def GetControlPivot(self):
1636    if self._pivot_for_body is not None:
1637      return self._pivot_for_body
1638    return self._pivot_for_pred
1639
1640  def AddValue(self, val):
1641    """Add `val` to the current context and its outer context recursively."""
1642    result = val
1643    new_value = val.name not in self._values
1644    # Don't treat ops in this context as new values. Usually all known values
1645    # are in self._values, except when we're importing a while loop inside this
1646    # WhileContext. Since there's a cycle in this case, `val` may be part of the
1647    # imported while loop but not yet processed by this context and added to
1648    # self._values in _AddOpInternal. We only want to process external input
1649    # tensors to the while loop here.
1650    new_value &= val.op._control_flow_context is not self  # pylint: disable=protected-access
1651    if new_value:
1652      self._values.add(val.name)
1653
1654      # If we are in a grad context and val is from its forward context,
1655      # use GetRealValue(), which adds the logic to save the history of
1656      # val in forward.
1657      grad_ctxt = ops.get_default_graph()._get_control_flow_context()
1658      if grad_ctxt:
1659        grad_ctxt = grad_ctxt.GetWhileContext()
1660        if grad_ctxt.grad_state:
1661          forward_ctxt = util.GetWhileContext(val.op)
1662          if util.IsLoopExit(val.op):
1663            forward_ctxt = forward_ctxt.outer_context
1664            if forward_ctxt:
1665              forward_ctxt = forward_ctxt.GetWhileContext()
1666          if forward_ctxt == grad_ctxt.grad_state.forward_context:
1667            real_val = grad_ctxt.grad_state.GetRealValue(val)
1668            self._external_values[val.name] = real_val
1669            return real_val
1670
1671      if self._outer_context is not None:
1672        result = self._outer_context.AddValue(val)
1673      # Create an Enter to make `result` known to this loop context.
1674      with ops.control_dependencies(None):
1675        enter = _Enter(
1676            result,
1677            self._name,
1678            is_constant=True,
1679            parallel_iterations=self._parallel_iterations)
1680        enter.graph.prevent_feeding(enter)
1681        if self._outer_context:
1682          self._outer_context.AddInnerOp(enter.op)
1683      # Fix the control inputs and control flow context of these enter ops.
1684      self._FixControlInputsAndContext([enter])
1685
1686      # Add `enter` in this context.
1687      self._values.add(enter.name)
1688      self._external_values[val.name] = enter
1689      result = enter
1690    else:
1691      actual_val = self._external_values.get(val.name)
1692      if actual_val is not None:
1693        result = actual_val
1694    return result
1695
1696  def AddOp(self, op):
1697    """Add `op` to the current context."""
1698    # For a reduction op, if op is in a grad context and its input is from
1699    # its forward context, moving op to the forward context means we would
1700    # store the tensor after the reduction as opposed to the tensor before
1701    # reduction, and therefore could significantly reduce memory consumption.
1702    # For now, we do this only for a few ops.
1703    #
1704    # If in XLA context, do not move constant ops to forward pass as pushing to
1705    # and popping from a stack removes the constant property of an op and breaks
1706    # XLA compilation, which requires certain inputs to be constant for certain
1707    # ops.
1708    if not util.IsInXLAContext(op) and op.type in {"Shape", "Size", "Rank"}:
1709      grad_ctxt = ops.get_default_graph()._get_control_flow_context()
1710      if grad_ctxt:
1711        grad_ctxt = grad_ctxt.GetWhileContext()
1712        if grad_ctxt.grad_state:
1713          op_input_forward_ctxt = util.GetWhileContext(op.inputs[0].op)
1714          if op_input_forward_ctxt == grad_ctxt.grad_state.forward_context:
1715            op_input_ctxt = op.inputs[0].op._get_control_flow_context()
1716            op._set_control_flow_context(op_input_ctxt)
1717            op_input_ctxt._AddOpInternal(op)
1718            return
1719    self._AddOpInternal(op)
1720
1721  def _AddOpInternal(self, op):
1722    """Add `op` to the current context.
1723
1724    We move any external control dependencies of the op to the loop pivot, to
1725    ensure they get executed.
1726    """
1727    if not op.inputs:
1728      # Remove any external control dependency on this op
1729      control_inputs, external_inputs = self._RemoveExternalControlEdges(op)
1730      # Add a control edge from the control pivot to this op.
1731      if not control_inputs:
1732        # pylint: disable=protected-access
1733        op._add_control_input(self.GetControlPivot().op)
1734        # pylint: enable=protected-access
1735      for x in op.outputs:
1736        self._values.add(x.name)
1737    else:
1738      for index in range(len(op.inputs)):
1739        x = op.inputs[index]
1740        real_x = self.AddValue(x)
1741        if real_x != x:
1742          op._update_input(index, real_x)  # pylint: disable=protected-access
1743      # Remove any external control dependency on this op.
1744      _, external_inputs = self._RemoveExternalControlEdges(op)
1745      # Add a control dependency to prevent loop invariants from
1746      # enabling ops that should not be executed.
1747      self._MaybeAddControlDependency(op)
1748      for x in op.outputs:
1749        self._values.add(x.name)
1750    if external_inputs:
1751      # Use an identity to pull control inputs as data inputs. Note that we
1752      # ignore ops which don't have outputs. TODO(apassos): fix that
1753      with ops.control_dependencies(None):
1754        self.Enter()
1755        external_inputs = [
1756            array_ops.identity(x.outputs[0]).op
1757            for x in external_inputs
1758            if x.outputs
1759        ]
1760        self.Exit()
1761      op._add_control_inputs(external_inputs)  # pylint: disable=protected-access
1762    if self._outer_context or not util.IsLoopExit(op):
1763      op.graph.prevent_fetching(op)
1764      for x in op.outputs:
1765        op.graph.prevent_feeding(x)
1766
1767    if self._outer_context:
1768      self._outer_context.AddInnerOp(op)
1769
1770  def _MaybeAddControlDependency(self, op):
1771    """Add a control input to the op if it only depends on loop invariants."""
1772
1773    def _IsOpFree(op):
1774      """Determines if `op` needs a control dependency."""
1775      if op.control_inputs:
1776        return False
1777      # pylint: disable=protected-access
1778      if op.graph._is_function(op.type) or op.type == "SymbolicGradient":
1779        return True
1780      # pylint: enable=protected-access
1781      for x in op.inputs:
1782        if not util.IsLoopConstantEnter(x.op):
1783          return False
1784      return True
1785
1786    if _IsOpFree(op):
1787      # pylint: disable=protected-access
1788      op._add_control_input(self.GetControlPivot().op)
1789      # pylint: enable=protected-access
1790
1791  def AddForwardLoopCounter(self, outer_grad_state):
1792    """Adds a loop that counts the number of iterations.
1793
1794    This is added to the forward loop at the time when we start to
1795    create the loop for backprop gradient computation. Called in
1796    the outer context of this forward context.
1797
1798    The pseudocode is:
1799      `n = 0; while (_pivot) { n++; }`
1800
1801    Note that a control dependency is added to `n` to ensure the correct
1802    execution order of stack push ops.
1803
1804    Args:
1805      outer_grad_state: The outer grad state. None if not nested.
1806
1807    Returns:
1808      The number of iterations taken by the forward loop and the loop index.
1809    """
1810    n = constant_op.constant(0, name="f_count")
1811    if outer_grad_state is not None:
1812      # Force the stack pushes of i-th execution of an inner loop to be ordered
1813      # before the pushes of (i+1)-th execution of the same inner loop.
1814      outer_add_op = outer_grad_state.forward_index.op.inputs[0].op
1815      n.op._add_control_input(outer_add_op)  # pylint: disable=protected-access
1816
1817    self.Enter()
1818    self.AddName(n.name)
1819    enter_n = _Enter(
1820        n,
1821        self._name,
1822        is_constant=False,
1823        parallel_iterations=self._parallel_iterations,
1824        name="f_count")
1825    self.loop_enters.append(enter_n)
1826
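    # Wire up the counter loop: Merge takes the Enter value on the first
    # iteration and the NextIteration value afterwards; Switch routes the
    # merged value either into the body (true output) or to Exit (false
    # output) depending on the loop pivot.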
1827    merge_n = merge([enter_n, enter_n])[0]
1828    switch_n = switch(merge_n, self._pivot)
1829
1830    index = math_ops.add(switch_n[1], 1)
1831    next_n = _NextIteration(index)
1832    merge_n.op._update_input(1, next_n)
1833
1834    total_iterations = exit(switch_n[0], name="f_count")
1835    self.loop_exits.append(total_iterations)
1836    self.ExitResult([total_iterations])
1837    self.Exit()
1838    return total_iterations, next_n
1839
1840  def AddBackpropLoopCounter(self, count, outer_grad_state):
1841    """Add the backprop loop that controls the iterations.
1842
1843    This is added to the backprop loop. It is used to control the loop
1844    termination of the backprop loop. Called in the outer context of
1845    this grad context.
1846
1847    The pseudocode is:
1848      `n = count; while (n >= 1) { n--; }`
1849
1850    Note that a control dependency is added to `final_zero` to ensure the
1851    correct execution order of stack pop ops.
1852
1853    Args:
1854      count: The number of iterations for backprop.
1855      outer_grad_state: The outer grad state. None if not nested.
1856
1857    Returns:
1858      The loop index.
1859    """
1860    in_separate_functions = count.graph is not ops.get_default_graph()
1861    if in_separate_functions:
1862      # Brings the count into this graph
1863      count = array_ops.identity(count)
1864    else:
1865      # TODO(apassos) XLA expects this constant to be created outside the loop,
1866      # so doing that for now.
1867      one = constant_op.constant(1, name="b_count")
1868
1869    self.Enter()
1870    self.AddName(count.name)
1871    enter_count = _Enter(
1872        count,
1873        self._name,
1874        is_constant=False,
1875        parallel_iterations=self._parallel_iterations,
1876        name="b_count")
1877    self.loop_enters.append(enter_count)
1878
1879    merge_count = merge([enter_count, enter_count])[0]
1880    self._pivot_for_pred = merge_count
1881
1882    if in_separate_functions:
1883      one = constant_op.constant(1, name="b_count")
1884    pred = math_ops.greater_equal(merge_count, one)
1885    self._pivot = loop_cond(pred, name="b_count")
1886    switch_count = switch(merge_count, self._pivot)
1887
1888    index = math_ops.subtract(switch_count[1], one)
1889    self._pivot_for_body = index
1890    next_count = _NextIteration(index)
1891    merge_count.op._update_input(1, next_count)
1892
1893    final_zero = exit(switch_count[0], name="b_count")
1894    self.loop_exits.append(final_zero)
1895    if outer_grad_state is not None:
1896      # Force the stack pops of i-th execution of an inner loop to be ordered
1897      # before the pops of (i+1)-th execution of the same inner loop.
1898      # pylint: disable=protected-access
1899      outer_grad_state.grad_sync._add_control_input(final_zero.op)
1900      # pylint: enable=protected-access
1901
1902    self.ExitResult([final_zero])
1903    self.Exit()
1904    return next_count
1905
1906  def AddBackpropAccumulator(self, op, grad):
1907    """Add an accumulation loop for every loop invariant.
1908
1909    This is added to the backprop loop. It is used to accumulate partial
1910    gradients within each loop iteration. Called when in the gradient while
1911    context.
1912
1913    The pseudocode is:
1914      ```
1915      acc = 0.0;
1916      while (_pivot) {
1917        acc += grad;
1918      }
1919      ```
1920
1921    Args:
1922      op: The Enter op for a loop invariant.
1923      grad: The partial gradient of an iteration for a loop invariant.
1924
1925    Returns:
1926      The gradient for a loop invariant.
1927    """
1928    self.Exit()
1929    # Create a zeros tensor with the right shape for acc. If we don't
1930    # know the full shape statically, we will have to get the shape
1931    # dynamically from the forward inference. Getting the shape right
1932    # for the zeros is only needed for the base case when the loop exits
1933    # without running any iterations.
1934    shape = grad.get_shape()
1935    if shape.is_fully_defined():
1936      if self.outer_context:
1937        self.outer_context.Enter()
1938      acc = constant_op.constant(0, grad.dtype, shape=shape, name="b_acc")
1939      if self.outer_context:
1940        self.outer_context.Exit()
1941    else:
1942      value = op.inputs[0]
1943      if (isinstance(self.outer_context, WhileContext) and
1944          self.outer_context.grad_state is not None):
1945        # We are in a nested while loop.
1946        forward_ctxt = self.grad_state.forward_context
1947        forward_ctxt.outer_context.Enter()
1948        zeros_shape = array_ops.shape_internal(value, optimize=False)
1949        forward_ctxt.outer_context.Exit()
1950        outer_grad_state = self.grad_state.outer_grad_state
1951        history_zeros_shape = outer_grad_state.AddForwardAccumulator(
1952            zeros_shape)
1953        self.outer_context.Enter()
1954        real_shape = outer_grad_state.AddBackpropAccumulatedValue(
1955            history_zeros_shape, zeros_shape)
1956        acc = array_ops.zeros(real_shape, grad.dtype)
1957        self.outer_context.Exit()
1958      else:
1959        if self.outer_context:
1960          self.outer_context.Enter()
1961        zeros_shape = array_ops.shape_internal(value, optimize=False)
1962        acc = array_ops.zeros(zeros_shape, grad.dtype)
1963        if self.outer_context:
1964          self.outer_context.Exit()
1965
1966    self.Enter()
1967    self.AddName(acc.name)
1968    enter_acc = _Enter(
1969        acc,
1970        self._name,
1971        is_constant=False,
1972        parallel_iterations=self._parallel_iterations,
1973        name="b_acc")
1974    self.loop_enters.append(enter_acc)
1975
1976    merge_acc = merge([enter_acc, enter_acc], name="b_acc")[0]
1977    switch_acc_false, switch_acc_true = switch(merge_acc, self._pivot)
1978
1979    add_acc = math_ops.add(switch_acc_true, grad)
1980    next_acc = _NextIteration(add_acc)
1981    merge_acc.op._update_input(1, next_acc)  # pylint: disable=protected-access
1982
1983    result_acc = exit(switch_acc_false, name="b_acc")
1984    self.loop_exits.append(result_acc)
1985    self.ExitResult([result_acc])
1986    return result_acc
1987
1988  def AddBackpropIndexedSlicesAccumulator(self, op, grad):
1989    """This is used for accumulating gradients that are IndexedSlices.
1990
1991    This is essentially the equivalent of AddBackpropAccumulator but optimized
1992    for things like updating embeddings from within a while loop.
1993
1994    Args:
1995      op: The Enter op for a loop invariant.
1996      grad: The partial gradients represented as an IndexedSlices.
1997
1998    Returns:
1999      The accumulated IndexedSlices gradient of the loop invariant.
2000    """
2001    values = grad.values
2002    indices = grad.indices
2003    dense_shape = grad.dense_shape
2004
2005    self.Exit()
2006    if self.outer_context:
2007      self.outer_context.Enter()
2008    if values.get_shape().is_fully_defined():
2009      values_shape = tensor_shape.TensorShape([tensor_shape.Dimension(1)] +
2010                                              values.get_shape().dims[1:])
2011      if self.outer_context:
2012        self.outer_context.Enter()
2013      values_acc = constant_op.constant(
2014          0, values.dtype, shape=values_shape, name="b_acc")
2015      if self.outer_context:
2016        self.outer_context.Exit()
2017    else:
2018      values_shape = _resource_safe_shape(op.inputs[0])[1:]
2019      values_shape = array_ops.concat([[1], values_shape], 0)
2020      values_acc = array_ops.zeros(values_shape, dtype=values.dtype)
2021    indices_acc = constant_op.constant([0], indices.dtype)
2022    shape_acc = None
2023    if dense_shape is not None:
2024      if dense_shape.get_shape().is_fully_defined():
2025        if self.outer_context:
2026          self.outer_context.Enter()
2027        shape_acc = constant_op.constant(
2028            0, dense_shape.dtype, shape=dense_shape.get_shape())
2029        if self.outer_context:
2030          self.outer_context.Exit()
2031      else:
2032        shape_acc = array_ops.zeros_like(
2033            array_ops.shape_internal(
2034                op.inputs[0], optimize=False, out_type=dense_shape.dtype),
2035            optimize=False)
2036
2037    if self.outer_context:
2038      self.outer_context.Exit()
2039
2040    self.Enter()
2041    self.AddName(values_acc.name)
2042    self.AddName(indices_acc.name)
2043    init_acc = [indices_acc, values_acc]
2044    if shape_acc is not None:
2045      self.AddName(shape_acc.name)
2046      init_acc.append(shape_acc)
2047
2048    # Set use_input_shape=False since the accumulator tensors will grow in
2049    # size. If use_input_shape=True, the _update_input call below will result in
2050    # incompatible shapes.
2051    enter_acc = [
2052        _Enter(
2053            x,
2054            self._name,
2055            is_constant=False,
2056            parallel_iterations=self._parallel_iterations,
2057            use_input_shape=False,
2058            name="b_acc") for x in init_acc
2059    ]
2060    # Manually set appropriate partial shapes.
2061    enter_acc[0].set_shape([None])
2062    if values_acc.shape.dims is not None:
2063      enter_acc[1].set_shape([None] + values_acc.shape.as_list()[1:])
2064    self.loop_enters.extend(enter_acc)
2065
2066    merge_acc = [merge([x, x], name="b_acc")[0] for x in enter_acc]
2067    switch_acc = [switch(x, self._pivot) for x in merge_acc]
2068
2069    # The actual accumulation.
2070    acc_indexed_slices = [
2071        array_ops.concat([xa[1], xv], 0)
2072        for xa, xv in zip(switch_acc[:2], [indices, values])
2073    ]
2074    if shape_acc is not None:
2075      # For the shape we just keep the maximum
2076      acc_indexed_slices.append(math_ops.maximum(dense_shape, switch_acc[2][1]))
2077
2078    next_acc = [_NextIteration(x) for x in acc_indexed_slices]
2079    for xm, xn in zip(merge_acc, next_acc):
2080      xm.op._update_input(1, xn)  # pylint: disable=protected-access
2081
2082    exit_acc = [exit(x[0], name="b_acc") for x in switch_acc]
2083    self.loop_exits.extend(exit_acc)
2084
2085    self.ExitResult(exit_acc)
2086    return ops.IndexedSlices(
2087        indices=exit_acc[0],
2088        values=exit_acc[1],
2089        dense_shape=exit_acc[2] if shape_acc is not None else None)
2090
2091  def _InitializeValues(self, values):
2092    """Makes the values known to this context."""
2093    self._values = set()
2094    for x in values:
2095      if isinstance(x, ops.Tensor):
2096        self._values.add(x.name)
2097      else:
2098        raise TypeError("Type %s not supported" % type(x))
2099
2100  def _BuildLoop(self, pred, body, original_loop_vars, loop_vars,
2101                 shape_invariants):
2102    """Core: Add the loop termination condition and body to the graph."""
2103    flat_loop_vars = nest.flatten(original_loop_vars, expand_composites=True)
2104
2105    # Let the context know the loop variables so the loop variables
2106    # would be added in the outer contexts properly.
2107    self._InitializeValues(loop_vars)
2108    real_vars = loop_vars
2109    if self._outer_context:
2110      real_vars = [self._outer_context.AddValue(x) for x in loop_vars]
2111    with ops.control_dependencies(None):
2112      enter_vars = [
2113          _Enter(
2114              x,
2115              self._name,
2116              is_constant=False,
2117              parallel_iterations=self._parallel_iterations,
2118              use_input_shape=(shape_invariants is None)) for x in real_vars
2119      ]
2120      for x in enter_vars:
2121        x.graph.prevent_feeding(x)
2122        if self._outer_context:
2123          self._outer_context.AddInnerOp(x.op)
2124
2125    # Finds the closest enclosing non-None control pivot.
2126    outer_context = self._outer_context
2127    control_pivot = None
2128    while outer_context is not None and control_pivot is None:
2129      control_pivot = outer_context.GetControlPivot()
2130      # pylint: disable=protected-access
2131      outer_context = outer_context._outer_context
2132      # pylint: enable=protected-access
2133
2134    if control_pivot is not None:
2135      for var in enter_vars:
2136        if util.IsLoopConstantEnter(var.op.inputs[0].op):
2137          # pylint: disable=protected-access
2138          var.op._add_control_input(control_pivot.op)
2139          # pylint: enable=protected-access
2140    _SetShapeInvariants(real_vars, enter_vars, shape_invariants)
2141
2142    # Fix the control inputs and control flow context of these enter ops.
2143    self._FixControlInputsAndContext(enter_vars)
2144    self._InitializeValues(enter_vars)
2145    self._loop_enters = enter_vars
2146
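    # Create Merge nodes for the loop variables. Each Merge initially takes
    # the Enter value on both inputs; its second input is rewired below to the
    # corresponding NextIteration node, forming the loop's back edge.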
2147    merge_vars = [merge([x, x])[0] for x in enter_vars]
2148    self._pivot_for_pred = merge_vars[0]
2149
2150    # Build the graph for pred.
2151    merge_vars_with_tensor_arrays = (
2152        _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars))
2153    packed_vars = nest.pack_sequence_as(
2154        structure=original_loop_vars,
2155        flat_sequence=merge_vars_with_tensor_arrays,
2156        expand_composites=True)
2157    c = ops.convert_to_tensor(pred(*packed_vars))
2158    self._pivot = loop_cond(c, name="LoopCond")
2159    switch_vars = [_SwitchRefOrTensor(x, self._pivot) for x in merge_vars]
2160
2161    # Build the graph for body.
2162    vars_for_body = [_Identity(x[1]) for x in switch_vars]
2163    self._pivot_for_body = vars_for_body[0]
2164    # Convert TensorArray flow variables inside the context back into
2165    # their associated TensorArrays for calling the body.
2166    vars_for_body_with_tensor_arrays = (
2167        _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body))
2168    packed_vars_for_body = nest.pack_sequence_as(
2169        structure=original_loop_vars,
2170        flat_sequence=vars_for_body_with_tensor_arrays,
2171        expand_composites=True)
2172    pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
2173    body_result = body(*packed_vars_for_body)
2174    post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
2175    if not nest.is_sequence_or_composite(body_result):
2176      body_result = [body_result]
2177    if len(post_summaries) > len(pre_summaries):
2178      new_summaries = post_summaries[len(pre_summaries):]
2179      summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
2180      summary_ref[:] = pre_summaries
2181      with ops.control_dependencies(new_summaries):
2182
2183        def map_fn(x):
2184          # TODO(apassos) figure out how to trigger with tensor arrays as well
2185          if isinstance(x, tensor_array_ops.TensorArray):
2186            return x
2187          return array_ops.identity(x)
2188
2189        body_result = nest.map_structure(
2190            map_fn, body_result, expand_composites=True)
2191
2192    # Compare the structure types of input and output of body.
2193    # For backwards compatibility, the first layer is forced to a list
2194    # during this comparison, because inputs are typically lists and
2195    # outputs of the body are typically tuples.
2196    nest.assert_same_structure(
2197        list(packed_vars_for_body), list(body_result), expand_composites=True)
2198
2199    # Store body_result to keep track of TensorArrays returned by body
2200    original_body_result = body_result
2201    # Convert TensorArrays returned by body into their flow variables
2202    result = nest.map_structure(
2203        _convert_tensorarray_to_flow,
2204        nest.flatten(body_result, expand_composites=True),
2205        expand_composites=True)
2206    result = ops.convert_n_to_tensor_or_composite(result)
2207
2208    # Add NextIteration and the back edges to complete the loop.
2209    if len(merge_vars) != len(result):
2210      raise ValueError("Number of inputs and outputs of body must match "
2211                       "loop_vars: %d, %d" % (len(merge_vars), len(result)))
2212    next_vars = []
2213    for m, v in zip(merge_vars, result):
2214      next_vars.append(_AddNextAndBackEdge(m, v))
2215
2216    # Add the exit ops.
2217    exit_vars = [exit(x[0]) for x in switch_vars]
2218    self._loop_exits = exit_vars
2219
2220    # Exit the loop.
2221    self.ExitResult(exit_vars)
2222
2223    return original_body_result, exit_vars
2224
2225  def BuildLoop(self, pred, body, loop_vars, shape_invariants,
2226                return_same_structure):
2227    """Add the loop termination condition and body to the graph."""
2228
2229    # Keep original_loop_vars to identify which are TensorArrays
2230    original_loop_vars = loop_vars
2231    # Convert TensorArrays to their flow variables
2232    loop_vars = nest.map_structure(
2233        _convert_tensorarray_to_flow,
2234        nest.flatten(loop_vars, expand_composites=False),
2235        expand_composites=True)
2236    loop_vars = ops.convert_n_to_tensor_or_composite(loop_vars)
2237    if shape_invariants is None:
2238      shape_invariants = nest.map_structure(
2239          _get_shape_invariant, loop_vars, expand_composites=False)
2240    loop_vars = nest.flatten(loop_vars, expand_composites=True)
2241    try:
2242      self.Enter()
2243      # _BuildLoop calls _update_input in several places. _mutation_lock()
2244      # ensures a Session.run call cannot occur between creating and mutating
2245      # new ops.
2246      with ops.get_default_graph()._mutation_lock():  # pylint: disable=protected-access
2247        original_body_result, exit_vars = self._BuildLoop(
2248            pred, body, original_loop_vars, loop_vars, shape_invariants)
2249    finally:
2250      self.Exit()
2251
2252    flat_result = nest.flatten(original_body_result, expand_composites=True)
2253    # Convert TensorArray flow variables outside the context back into
2254    # their associated TensorArrays for returning to caller.
2255    exit_vars_with_tensor_arrays = (
2256        _convert_flows_to_tensorarrays(flat_result, exit_vars))
2257    packed_exit_vars = nest.pack_sequence_as(
2258        structure=original_body_result,
2259        flat_sequence=exit_vars_with_tensor_arrays,
2260        expand_composites=True)
2261
2262    if return_same_structure:
2263      return packed_exit_vars
2264    else:
2265      return packed_exit_vars[0] if len(exit_vars) == 1 else packed_exit_vars
2266
2267  def _FixControlInputsAndContext(self, enters):
2268    graph = ops.get_default_graph()
2269    # pylint: disable=protected-access
2270    for e in enters:
2271      if isinstance(e, ops.Tensor):
2272        xs = [e]
2273      else:
2274        raise TypeError("Type %s not supported" % type(e))
2275      for x in xs:
2276        inp_op = x.op.inputs[0].op
2277        control_inputs = graph._control_dependencies_for_inputs([inp_op])
2278        outer_control_inputs = []
2279        for op in control_inputs:
2280          # We need to keep control inputs that are in any ancestor
2281          # ControlFlowContext, and within outer WhileContext.
2282          keep_as_control_input = True
2283          op_ctxt = util.GetOutputContext(op)
2284          outer_ctxt = self.outer_context
2285          outer_while_context = (None if outer_ctxt is None else
2286                                 outer_ctxt.GetWhileContext())
2287          while outer_ctxt != op_ctxt:
2288            if outer_ctxt is None or outer_ctxt == outer_while_context:
2289              keep_as_control_input = False
2290              break
2291            outer_ctxt = outer_ctxt.outer_context
2292          if keep_as_control_input:
2293            outer_control_inputs.append(op)
2294        x.op._set_control_flow_context(self)
2295        x.op._add_control_inputs(outer_control_inputs)
2296        graph._record_op_seen_by_control_dependencies(x.op)
2297    # pylint: enable=protected-access
2298
2299  def IsWhileContext(self):
2300    return True
2301
2302
2303# @TODO(b/133606651) Replace "shape_invariants" with "loop_vars_signature".
2304# pylint: disable=redefined-outer-name
2305@tf_export("while_loop", v1=[])
2306@deprecation.deprecated_arg_values(
2307    None,
2308    """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
2309Instead of:
2310results = tf.while_loop(c, b, vars, back_prop=False)
2311Use:
2312results = tf.nest.map_structure(tf.stop_gradient, tf.while_loop(c, b, vars))""",
2313    warn_once=True,
2314    back_prop=False)
2315def while_loop_v2(cond,
2316                  body,
2317                  loop_vars,
2318                  shape_invariants=None,
2319                  parallel_iterations=10,
2320                  back_prop=True,
2321                  swap_memory=False,
2322                  maximum_iterations=None,
2323                  name=None):
2324  """Repeat `body` while the condition `cond` is true.
2325
2326  `cond` is a callable returning a boolean scalar tensor. `body` is a callable
2327  returning a (possibly nested) tuple, namedtuple or list of tensors of the same
2328  arity (length and structure) and types as `loop_vars`. `loop_vars` is a
2329  (possibly nested) tuple, namedtuple or list of tensors that is passed to both
2330  `cond` and `body`. `cond` and `body` both take as many arguments as there are
2331  `loop_vars`.
2332
2333  In addition to regular Tensors or IndexedSlices, the body may accept and
2334  return TensorArray objects.  The flows of the TensorArray objects will
2335  be appropriately forwarded between loops and during gradient calculations.
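
  For instance (a minimal sketch, not part of the original documentation), a
  `tf.TensorArray` can collect one value per iteration:

  ```python
  ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
  c = lambda i, ta: i < 5
  b = lambda i, ta: (i + 1, ta.write(i, tf.cast(i, tf.float32) * 2.0))
  _, ta_final = tf.while_loop(c, b, [tf.constant(0), ta])
  result = ta_final.stack()  # [0., 2., 4., 6., 8.]
  ```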
2336
2337  Note that `while_loop` calls `cond` and `body` *exactly once* (inside the
2338  call to `while_loop`, and not at all during `Session.run()`). `while_loop`
2339  stitches together the graph fragments created during the `cond` and `body`
2340  calls with some additional graph nodes to create the graph flow that
2341  repeats `body` until `cond` returns false.
2342
2343  For correctness, `tf.while_loop()` strictly enforces shape invariants for
2344  the loop variables. A shape invariant is a (possibly partial) shape that
2345  is unchanged across the iterations of the loop. An error will be raised
2346  if the shape of a loop variable after an iteration is determined to be more
2347  general than or incompatible with its shape invariant. For example, a shape
2348  of [11, None] is more general than a shape of [11, 17], and [11, 21] is not
2349  compatible with [11, 17]. By default (if the argument `shape_invariants` is
2350  not specified), it is assumed that the initial shape of each tensor in
2351  `loop_vars` is the same in every iteration. The `shape_invariants` argument
2352  allows the caller to specify a less specific shape invariant for each loop
2353  variable, which is needed if the shape varies between iterations. The
2354  `tf.Tensor.set_shape`
2355  function may also be used in the `body` function to indicate that
  the output loop variable has a particular shape. The shape invariants for
  SparseTensor and IndexedSlices are treated specially as follows:
2358
2359  a) If a loop variable is a SparseTensor, the shape invariant must be
2360  TensorShape([r]) where r is the rank of the dense tensor represented
2361  by the sparse tensor. It means the shapes of the three tensors of the
2362  SparseTensor are ([None], [None, r], [r]). NOTE: The shape invariant here
2363  is the shape of the SparseTensor.dense_shape property. It must be the shape of
2364  a vector.
2365
2366  b) If a loop variable is an IndexedSlices, the shape invariant must be
2367  a shape invariant of the values tensor of the IndexedSlices. It means
2368  the shapes of the three tensors of the IndexedSlices are (shape, [shape[0]],
2369  [shape.ndims]).
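
  For example (a sketch for illustration only), a rank-2 `SparseTensor` loop
  variable whose leading dimension grows uses `tf.TensorShape([2])` as its
  invariant, i.e. the shape of its `dense_shape` vector:

  ```python
  i0 = tf.constant(0)
  st0 = tf.SparseTensor(indices=[[0, 0]], values=[1.0], dense_shape=[1, 3])
  c = lambda i, st: i < 3
  b = lambda i, st: (i + 1, tf.sparse.concat(0, [st, st]))
  _, st_final = tf.while_loop(
      c, b, [i0, st0],
      shape_invariants=[i0.get_shape(), tf.TensorShape([2])])
  ```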
2370
2371  `while_loop` implements non-strict semantics, enabling multiple iterations
2372  to run in parallel. The maximum number of parallel iterations can be
2373  controlled by `parallel_iterations`, which gives users some control over
2374  memory consumption and execution order. For correct programs, `while_loop`
2375  should return the same result for any parallel_iterations > 0.
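
  For instance (an illustrative sketch; `c`, `b` and `i` are as in the first
  example below), iterations can be forced to run one at a time:

  ```python
  r = tf.while_loop(c, b, [i], parallel_iterations=1)
  ```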
2376
2377  For training, TensorFlow stores the tensors that are produced in the
2378  forward inference and are needed in back propagation. These tensors are a
2379  main source of memory consumption and often cause OOM errors when training
2380  on GPUs. When the flag swap_memory is true, we swap out these tensors from
2381  GPU to CPU. This for example allows us to train RNN models with very long
2382  sequences and large batches.
2383
2384  Args:
2385    cond: A callable that represents the termination condition of the loop.
2386    body: A callable that represents the loop body.
2387    loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array,
2388      `Tensor`, and `TensorArray` objects.
2389    shape_invariants: The shape invariants for the loop variables.
2390    parallel_iterations: The number of iterations allowed to run in parallel. It
2391      must be a positive integer.
2392    back_prop: (optional) Deprecated. False disables support for back
2393      propagation. Prefer using `tf.stop_gradient` instead.
2394    swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
2395    maximum_iterations: Optional maximum number of iterations of the while loop
2396      to run.  If provided, the `cond` output is AND-ed with an additional
2397      condition ensuring the number of iterations executed is no greater than
2398      `maximum_iterations`.
2399    name: Optional name prefix for the returned tensors.
2400
2401  Returns:
2402    The output tensors for the loop variables after the loop. The return value
2403      has the same structure as `loop_vars`.
2404
2405  Raises:
2406    TypeError: if `cond` or `body` is not callable.
2407    ValueError: if `loop_vars` is empty.
2408
2409  Example:
2410
2411  ```python
2412  i = tf.constant(0)
2413  c = lambda i: tf.less(i, 10)
2414  b = lambda i: (tf.add(i, 1), )
2415  r = tf.while_loop(c, b, [i])
2416  ```
2417
2418  Example with nesting and a namedtuple:
2419
2420  ```python
2421  import collections
2422  Pair = collections.namedtuple('Pair', 'j, k')
2423  ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
2424  c = lambda i, p: i < 10
2425  b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k)))
2426  ijk_final = tf.while_loop(c, b, ijk_0)
2427  ```
2428
2429  Example using shape_invariants:
2430
2431  ```python
2432  i0 = tf.constant(0)
2433  m0 = tf.ones([2, 2])
2434  c = lambda i, m: i < 10
2435  b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
2436  tf.while_loop(
2437      c, b, loop_vars=[i0, m0],
2438      shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
2439  ```
2440
2441  Example which demonstrates non-strict semantics: In the following
2442  example, the final value of the counter `i` does not depend on `x`. So
2443  the `while_loop` can increment the counter parallel to updates of `x`.
2444  However, because the loop counter at one loop iteration depends
2445  on the value at the previous iteration, the loop counter itself cannot
2446  be incremented in parallel. Hence if we just want the final value of the
2447  counter (which we print on the line `print(sess.run(i))`), then
2448  `x` will never be incremented, but the counter will be updated on a
2449  single thread. Conversely, if we want the value of the output (which we
2450  print on the line `print(sess.run(out).shape)`), then the counter may be
2451  incremented on its own thread, while `x` can be incremented in
2452  parallel on a separate thread. In the extreme case, it is conceivable
2453  that the thread incrementing the counter runs until completion before
  `x` is incremented even a single time. The only thing that can never
  happen is for the thread updating `x` to get ahead of the counter
  thread, because the thread incrementing `x` depends on the value of
  the counter.
2458
2459  ```python
2460  import tensorflow as tf
2461
2462  n = 10000
2463  x = tf.constant(list(range(n)))
2464  c = lambda i, x: i < n
2465  b = lambda i, x: (tf.compat.v1.Print(i + 1, [i]), tf.compat.v1.Print(x + 1,
2466  [i], "x:"))
2467  i, out = tf.while_loop(c, b, (0, x))
2468  with tf.compat.v1.Session() as sess:
2469      print(sess.run(i))  # prints [0] ... [9999]
2470
2471      # The following line may increment the counter and x in parallel.
2472      # The counter thread may get ahead of the other thread, but not the
2473      # other way around. So you may see things like
2474      # [9996] x:[9987]
2475      # meaning that the counter thread is on iteration 9996,
2476      # while the other thread is on iteration 9987
2477      print(sess.run(out).shape)
2478  ```
2479
2480  """
2481  return while_loop(
2482      cond=cond,
2483      body=body,
2484      loop_vars=loop_vars,
2485      shape_invariants=shape_invariants,
2486      parallel_iterations=parallel_iterations,
2487      back_prop=back_prop,
2488      swap_memory=swap_memory,
2489      name=name,
2490      maximum_iterations=maximum_iterations,
2491      return_same_structure=True)
2492
2493
2494# pylint: disable=redefined-outer-name
2495@tf_export(v1=["while_loop"])
2496def while_loop(cond,
2497               body,
2498               loop_vars,
2499               shape_invariants=None,
2500               parallel_iterations=10,
2501               back_prop=True,
2502               swap_memory=False,
2503               name=None,
2504               maximum_iterations=None,
2505               return_same_structure=False):
2506  """Repeat `body` while the condition `cond` is true.
2507
2508  `cond` is a callable returning a boolean scalar tensor. `body` is a callable
2509  returning a (possibly nested) tuple, namedtuple or list of tensors of the same
2510  arity (length and structure) and types as `loop_vars`. `loop_vars` is a
2511  (possibly nested) tuple, namedtuple or list of tensors that is passed to both
2512  `cond` and `body`. `cond` and `body` both take as many arguments as there are
2513  `loop_vars`.
2514
2515  In addition to regular Tensors or IndexedSlices, the body may accept and
2516  return TensorArray objects.  The flows of the TensorArray objects will
2517  be appropriately forwarded between loops and during gradient calculations.
2518
2519  Note that `while_loop` calls `cond` and `body` *exactly once* (inside the
2520  call to `while_loop`, and not at all during `Session.run()`). `while_loop`
2521  stitches together the graph fragments created during the `cond` and `body`
2522  calls with some additional graph nodes to create the graph flow that
2523  repeats `body` until `cond` returns false.
2524
2525  For correctness, `tf.while_loop()` strictly enforces shape invariants for
2526  the loop variables. A shape invariant is a (possibly partial) shape that
2527  is unchanged across the iterations of the loop. An error will be raised
2528  if the shape of a loop variable after an iteration is determined to be more
2529  general than or incompatible with its shape invariant. For example, a shape
2530  of [11, None] is more general than a shape of [11, 17], and [11, 21] is not
2531  compatible with [11, 17]. By default (if the argument `shape_invariants` is
2532  not specified), it is assumed that the initial shape of each tensor in
2533  `loop_vars` is the same in every iteration. The `shape_invariants` argument
2534  allows the caller to specify a less specific shape invariant for each loop
2535  variable, which is needed if the shape varies between iterations. The
2536  `tf.Tensor.set_shape`
2537  function may also be used in the `body` function to indicate that
  the output loop variable has a particular shape. The shape invariants for
  SparseTensor and IndexedSlices are treated specially as follows:
2540
2541  a) If a loop variable is a SparseTensor, the shape invariant must be
2542  TensorShape([r]) where r is the rank of the dense tensor represented
2543  by the sparse tensor. It means the shapes of the three tensors of the
2544  SparseTensor are ([None], [None, r], [r]). NOTE: The shape invariant here
2545  is the shape of the SparseTensor.dense_shape property. It must be the shape of
2546  a vector.
2547
2548  b) If a loop variable is an IndexedSlices, the shape invariant must be
2549  a shape invariant of the values tensor of the IndexedSlices. It means
2550  the shapes of the three tensors of the IndexedSlices are (shape, [shape[0]],
2551  [shape.ndims]).
2552
2553  `while_loop` implements non-strict semantics, enabling multiple iterations
2554  to run in parallel. The maximum number of parallel iterations can be
2555  controlled by `parallel_iterations`, which gives users some control over
2556  memory consumption and execution order. For correct programs, `while_loop`
2557  should return the same result for any parallel_iterations > 0.
2558
2559  For training, TensorFlow stores the tensors that are produced in the
2560  forward inference and are needed in back propagation. These tensors are a
2561  main source of memory consumption and often cause OOM errors when training
2562  on GPUs. When the flag swap_memory is true, we swap out these tensors from
2563  GPU to CPU. This for example allows us to train RNN models with very long
2564  sequences and large batches.
2565
2566  Args:
2567    cond: A callable that represents the termination condition of the loop.
2568    body: A callable that represents the loop body.
2569    loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array,
2570      `Tensor`, and `TensorArray` objects.
2571    shape_invariants: The shape invariants for the loop variables.
2572    parallel_iterations: The number of iterations allowed to run in parallel. It
2573      must be a positive integer.
2574    back_prop: Whether backprop is enabled for this while loop.
2575    swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
2576    name: Optional name prefix for the returned tensors.
2577    maximum_iterations: Optional maximum number of iterations of the while loop
2578      to run.  If provided, the `cond` output is AND-ed with an additional
2579      condition ensuring the number of iterations executed is no greater than
2580      `maximum_iterations`.
    return_same_structure: If True, the output has the same structure as
      `loop_vars`. If eager execution is enabled, this is ignored (and always
      treated as True).
2583
2584  Returns:
2585    The output tensors for the loop variables after the loop.
2586     If `return_same_structure` is True, the return value has the same
2587     structure as `loop_vars`.
2588     If `return_same_structure` is False, the return value is a Tensor,
     TensorArray or IndexedSlices if the length of `loop_vars` is 1, or a list
2590     otherwise.
2591
2592  Raises:
2593    TypeError: if `cond` or `body` is not callable.
2594    ValueError: if `loop_vars` is empty.
2595
2596  Example:
2597
2598  ```python
2599  i = tf.constant(0)
2600  c = lambda i: tf.less(i, 10)
2601  b = lambda i: tf.add(i, 1)
2602  r = tf.while_loop(c, b, [i])
2603  ```
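
  Example using `maximum_iterations` (a minimal sketch; the bound caps the
  loop even though `cond` alone would allow more iterations):

  ```python
  i = tf.constant(0)
  c = lambda i: tf.less(i, 10)
  b = lambda i: tf.add(i, 1)
  r = tf.while_loop(c, b, [i], maximum_iterations=3)  # r evaluates to 3
  ```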
2604
2605  Example with nesting and a namedtuple:
2606
2607  ```python
2608  import collections
2609  Pair = collections.namedtuple('Pair', 'j, k')
2610  ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
2611  c = lambda i, p: i < 10
2612  b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k)))
2613  ijk_final = tf.while_loop(c, b, ijk_0)
2614  ```
2615
2616  Example using shape_invariants:
2617
2618  ```python
2619  i0 = tf.constant(0)
2620  m0 = tf.ones([2, 2])
2621  c = lambda i, m: i < 10
2622  b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
2623  tf.while_loop(
2624      c, b, loop_vars=[i0, m0],
2625      shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
2626  ```
2627
2628  Example which demonstrates non-strict semantics: In the following
2629  example, the final value of the counter `i` does not depend on `x`. So
2630  the `while_loop` can increment the counter parallel to updates of `x`.
2631  However, because the loop counter at one loop iteration depends
2632  on the value at the previous iteration, the loop counter itself cannot
2633  be incremented in parallel. Hence if we just want the final value of the
2634  counter (which we print on the line `print(sess.run(i))`), then
2635  `x` will never be incremented, but the counter will be updated on a
2636  single thread. Conversely, if we want the value of the output (which we
2637  print on the line `print(sess.run(out).shape)`), then the counter may be
2638  incremented on its own thread, while `x` can be incremented in
2639  parallel on a separate thread. In the extreme case, it is conceivable
2640  that the thread incrementing the counter runs until completion before
  `x` is incremented even a single time. The only thing that can never
  happen is for the thread updating `x` to get ahead of the counter
  thread, because the thread incrementing `x` depends on the value of
  the counter.
2645
2646  ```python
2647  import tensorflow as tf
2648
2649  n = 10000
2650  x = tf.constant(list(range(n)))
2651  c = lambda i, x: i < n
2652  b = lambda i, x: (tf.compat.v1.Print(i + 1, [i]), tf.compat.v1.Print(x + 1,
2653  [i], "x:"))
2654  i, out = tf.while_loop(c, b, (0, x))
2655  with tf.compat.v1.Session() as sess:
2656      print(sess.run(i))  # prints [0] ... [9999]
2657
2658      # The following line may increment the counter and x in parallel.
2659      # The counter thread may get ahead of the other thread, but not the
2660      # other way around. So you may see things like
2661      # [9996] x:[9987]
2662      # meaning that the counter thread is on iteration 9996,
2663      # while the other thread is on iteration 9987
2664      print(sess.run(out).shape)
2665  ```
2666
2667  """
2668  if not callable(cond):
2669    raise TypeError("cond must be callable.")
2670  if not callable(body):
2671    raise TypeError("body must be callable.")
2672  if parallel_iterations < 1:
2673    raise TypeError("parallel_iterations must be a positive integer.")
2674
2675  # Always enable control flow v2 if building a function, regardless of toggle.
2676  executing_eagerly = context.executing_eagerly()
2677  if (util.EnableControlFlowV2(ops.get_default_graph()) and
2678      not executing_eagerly):
2679    return while_v2.while_loop(
2680        cond,
2681        body,
2682        loop_vars,
2683        shape_invariants=shape_invariants,
2684        parallel_iterations=parallel_iterations,
2685        maximum_iterations=maximum_iterations,
2686        name=name,
2687        return_same_structure=return_same_structure,
2688        back_prop=back_prop)
2689
2690  with ops.name_scope(name, "while", loop_vars):
2691    if not loop_vars:
2692      raise ValueError("No loop variables provided")
2693    try_to_pack = (len(loop_vars) == 1 and not return_same_structure)
2694    if maximum_iterations is not None:
2695      maximum_iterations = ops.convert_to_tensor(
2696          maximum_iterations, name="maximum_iterations")
2697      if maximum_iterations.shape.ndims != 0:
2698        raise ValueError("maximum_iterations must be a scalar, saw shape: %s" %
2699                         maximum_iterations.shape)
2700
2701      if executing_eagerly:
2702        counter = 0
2703        maximum_iterations = int(maximum_iterations.numpy())
2704      else:
2705        counter = constant_op.constant(
2706            0, dtype=maximum_iterations.dtype, name="iteration_counter")
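      # Wrap the user's cond/body: thread an iteration counter through the
      # loop variables and AND the original condition with
      # `counter < maximum_iterations`, so the loop also stops at the limit.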
2707      orig_cond = cond
2708      orig_body = body
2709      if try_to_pack:
2710        loop_vars = (counter, loop_vars[0])
2711        cond = lambda i, lv: (  # pylint: disable=g-long-lambda
2712            math_ops.logical_and(i < maximum_iterations, orig_cond(lv)))
2713        body = lambda i, lv: (i + 1, orig_body(lv))
2714      else:
2715        loop_vars = (counter, loop_vars)
2716        cond = lambda i, lv: (  # pylint: disable=g-long-lambda
2717            math_ops.logical_and(i < maximum_iterations, orig_cond(*lv)))
2718        body = lambda i, lv: (i + 1, orig_body(*lv))
2719      try_to_pack = False
2720
2721    if executing_eagerly:
2722      packed = False  # whether the body result was packed into a 1-item tuple
2723
2724      loop_var_structure = nest.map_structure(type_spec.type_spec_from_value,
2725                                              list(loop_vars))
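      # In eager mode the loop runs as a plain Python while loop; each
      # iteration checks that the body preserved the structure of the loop
      # variables.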
2726      while cond(*loop_vars):
2727        loop_vars = body(*loop_vars)
2728        if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
2729          packed = True
2730          loop_vars = (loop_vars,)
2731        nest.assert_same_structure(loop_var_structure, list(loop_vars))
2732
2733      def convert(x):
2734        if isinstance(x, tensor_array_ops.TensorArray):
2735          return x
2736        return ops.convert_to_tensor(x)
2737
2738      loop_vars = nest.map_structure(convert, loop_vars, expand_composites=True)
2739      if maximum_iterations is not None:
2740        return loop_vars[1]
2741      else:
2742        return loop_vars[0] if packed else loop_vars
2743
2744    if shape_invariants is not None:
2745      if maximum_iterations is not None:
2746        shape_invariants = (tensor_shape.TensorShape([]), shape_invariants)
2747
2748      nest.assert_same_structure(
2749          loop_vars, shape_invariants, expand_composites=False)
2750      shape_invariants = nest.map_structure(
2751          _get_shape_invariant,
2752          loop_vars,
2753          shape_invariants,
2754          expand_composites=False)
2755
2756    loop_context = WhileContext(
2757        maximum_iterations=maximum_iterations,
2758        parallel_iterations=parallel_iterations,
2759        back_prop=back_prop,
2760        swap_memory=swap_memory)
2761    # Only add non-nested loops to the collection. Any nested control flow will
2762    # be encapsulated in the root context.
2763    if loop_context.outer_context is None:
2764      ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context)
2765    result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants,
2766                                    return_same_structure)
2767    if maximum_iterations is not None:
2768      return result[1]
2769    else:
2770      return result
2771
2772
2773# pylint: enable=redefined-outer-name
2774
2775
2776def _AsTensorList(x, p):
2777  """Return x as a list of Tensors or IndexedSlices.
2778
2779  For entries of `x` that are Operations, this returns an Identity of `p`
2780  with a dependency on the operation.
2781
2782  Args:
2783    x: A Tensor/IndexedSlices/Operation or a list or tuple of them.
2784    p: A Tensor to return for entries in `x` that are Operations.
2785
2786  Returns:
2787    A list of Tensors or IndexedSlices.
2788  """
2789  if not isinstance(x, (list, _basetuple)):
2790    x = [x]
2791
2792  l = []
2793  for v in x:
2794    if isinstance(v, ops.Operation):
2795      v = with_dependencies([v], p)
2796    v = ops.convert_to_tensor_or_composite(v)
2797    if isinstance(v, ops.Tensor):
2798      l.append(array_ops.identity(v))
2799    else:
2800      l.append(
2801          ops.IndexedSlices(
2802              array_ops.identity(v.values), array_ops.identity(v.indices)))
2803  return l
2804
2805
2806def _CheckResults(a, b):
2807  assert len(a) == len(b), (
2808      "Values returned by a() and b() must have the same length.")
2809  for x, y in zip(a, b):
2810    assert x.dtype == y.dtype, (
2811        "Values returned by a() [%s] and b() [%s] must have "
2812        "the same type: %s, %s." % (x.name, y.name, x.dtype.name, y.dtype.name))
2813
2814
2815def with_dependencies(dependencies, output_tensor, name=None):
2816  """Produces the content of `output_tensor` only after `dependencies`.
2817
2818  In some cases, a user may want the output of an operation to be
2819  consumed externally only after some other dependencies have run
2820  first. This function returns `output_tensor`, but only after all
2821  operations in `dependencies` have run. Note that the control dependency is
2822  attached only to the returned tensor; the original `output_tensor` op may
2823  still be evaluated before any of the `dependencies` have run.
2824
2825  See also `tf.tuple` and `tf.group`.
2826
2827  Args:
2828    dependencies: Iterable of operations to run before this op finishes.
2829    output_tensor: A `Tensor` or `IndexedSlices` that will be returned.
2830    name: (Optional) A name for this operation.
2831
2832  Returns:
2833    Same as `output_tensor`.
2834
2835  Raises:
2836    TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`.
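
  Example (a minimal sketch, assuming TF1-style graph mode and calling the
  `with_dependencies` defined in this module; `log_op` and `y` are
  illustrative names):

  ```python
  import tensorflow.compat.v1 as tf
  tf.disable_eager_execution()

  x = tf.constant(3.0)
  log_op = tf.print("computing y")  # an op that should run first
  # `y` has the same value as `x * 2.0`, but evaluating `y` also forces
  # `log_op` to run.
  y = with_dependencies([log_op], x * 2.0)

  with tf.Session() as sess:
    print(sess.run(y))  # 6.0, after "computing y" is logged
  ```
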
2837  """
2838  if context.executing_eagerly():
2839    return output_tensor
2840  with ops.name_scope(name, "control_dependency",
2841                      list(dependencies) + [output_tensor]) as name:
2842    with ops.colocate_with(output_tensor):
2843      with ops.control_dependencies(dependencies):
2844        output_tensor = ops.convert_to_tensor_or_composite(output_tensor)
2845        if isinstance(output_tensor, ops.Tensor):
2846          return _Identity(output_tensor, name=name)
2847        else:
2848          return ops.IndexedSlices(
2849              _Identity(output_tensor.values, name=name), output_tensor.indices,
2850              output_tensor.dense_shape)
2851
2852
2853def _GroupControlDeps(dev, deps, name=None):
2854  with ops.control_dependencies(deps):
2855    if dev is None:
2856      return no_op(name=name)
2857    else:
2858      with ops.device(dev):
2859        return no_op(name=name)
2860
2861
2862# TODO(touts): Accept "inputs" as a list.
2863@tf_export("group")
2864def group(*inputs, **kwargs):
2865  """Create an op that groups multiple operations.
2866
2867  When this op finishes, all ops in `inputs` have finished. This op has no
2868  output.
2869
2870  See also `tf.tuple` and
2871  `tf.control_dependencies`.
2872
2873  Args:
2874    *inputs: Zero or more tensors to group.
2875    name: A name for this operation (optional).
2876
2877  Returns:
2878    An Operation that executes all its inputs.
2879
2880  Raises:
2881    ValueError: If an unknown keyword argument is provided.
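
  Example (a minimal sketch, assuming TF1-style graph mode; `a`, `b` and the
  assign ops are illustrative):

  ```python
  import tensorflow.compat.v1 as tf
  tf.disable_eager_execution()

  a = tf.Variable(0.0)
  b = tf.Variable(0.0)
  # `updates` has no output; running it runs both assignments.
  updates = tf.group(a.assign(1.0), b.assign(2.0))

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(updates)
  ```
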
2882  """
2883  if context.executing_eagerly():
2884    return None
2885  name = kwargs.pop("name", None)
2886  if kwargs:
2887    raise ValueError("Unknown keyword arguments: " + ", ".join(kwargs.keys()))
2888  with ops.name_scope(name, "group_deps", inputs) as name:
2889    # Grouping no inputs means do nothing
2890    if not inputs:
2891      return no_op(name=name)
2892
2893    # Sorts *inputs according to their devices.
2894    ops_on_device = {}  # device -> operations specified on the device.
2895    for inp in nest.flatten(inputs, expand_composites=True):
2896      if not hasattr(inp, "device"):
2897        raise TypeError("tf.group() expected Tensor arguments, not "
2898                        "'%s' with type '%s'." % (inp, type(inp)))
2899      dev = inp.device
2900      if dev in ops_on_device:
2901        ops_on_device[dev].append(inp)
2902      else:
2903        ops_on_device[dev] = [inp]
2904    if len(ops_on_device) == 1:
2905      # 1-level tree. The root node is the returned NoOp node.
2906      (dev, deps), = ops_on_device.items()
2907      return _GroupControlDeps(dev, deps, name=name)
2908
2909    # 2-level tree. The root node is the returned NoOp node.
2910    # deps contains 1 NoOp node for each device.
2911    deps = []
2912
2913    def device_key(dev):
2914      """A sort key that allows None to be compared to strings."""
2915      return "" if dev is None else dev
2916
2917    for dev in sorted(ops_on_device, key=device_key):
2918      deps.append(_GroupControlDeps(dev, ops_on_device[dev]))
2919
2920    with ops.control_dependencies(deps):
2921      return no_op(name=name)
2922
2923
2924@tf_export("tuple", v1=[])
2925def tuple_v2(tensors, control_inputs=None, name=None):
2926  """Group tensors together.
2927
2928  This creates a tuple of tensors with the same values as the `tensors`
2929  argument, except that the value of each tensor is only returned after the
2930  values of all tensors have been computed.
2931
2932  `control_inputs` contains additional ops that have to finish before this op
2933  finishes, but whose outputs are not returned.
2934
2935  This can be used as a "join" mechanism for parallel computations: all the
2936  argument tensors can be computed in parallel, but the values of any tensor
2937  returned by `tuple` are only available after all the parallel computations
2938  are done.
2939
2940  See also `tf.group` and
2941  `tf.control_dependencies`.
2942
2943  Args:
2944    tensors: A list of `Tensor`s or `IndexedSlices`; some entries can be `None`.
2945    control_inputs: List of additional ops to finish before returning.
2946    name: (optional) A name to use as a `name_scope` for the operation.
2947
2948  Returns:
2949    Same as `tensors`.
2950
2951  Raises:
2952    ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`.
2953    TypeError: If `control_inputs` is not a list of `Operation` or `Tensor`
2954      objects.
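
  Example (a minimal sketch; `f`, `a` and `b` are illustrative names):

  ```python
  import tensorflow as tf

  @tf.function
  def f(x):
    a = tf.square(x)
    b = tf.sqrt(x)
    # `a_out` and `b_out` carry the same values as `a` and `b`, but neither
    # is available until both computations have finished.
    a_out, b_out = tf.tuple([a, b])
    return a_out, b_out

  print(f(tf.constant(4.0)))  # (16.0, 2.0)
  ```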
2955
2956  """
2957  return tuple(tensors=tensors, name=name, control_inputs=control_inputs)  # pylint: disable=redefined-builtin
2958
2959
2960@tf_export(v1=["tuple"])
2961def tuple(tensors, name=None, control_inputs=None):  # pylint: disable=redefined-builtin
2962  """Group tensors together.
2963
2964  This creates a tuple of tensors with the same values as the `tensors`
2965  argument, except that the value of each tensor is only returned after the
2966  values of all tensors have been computed.
2967
2968  `control_inputs` contains additional ops that have to finish before this op
2969  finishes, but whose outputs are not returned.
2970
2971  This can be used as a "join" mechanism for parallel computations: all the
2972  argument tensors can be computed in parallel, but the values of any tensor
2973  returned by `tuple` are only available after all the parallel computations
2974  are done.
2975
2976  See also `tf.group` and
2977  `tf.control_dependencies`.
2978
2979  Args:
2980    tensors: A list of `Tensor`s or `IndexedSlices`; some entries can be `None`.
2981    name: (optional) A name to use as a `name_scope` for the operation.
2982    control_inputs: List of additional ops to finish before returning.
2983
2984  Returns:
2985    Same as `tensors`.
2986
2987  Raises:
2988    ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`.
2989    TypeError: If `control_inputs` is not a list of `Operation` or `Tensor`
2990      objects.
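
  Example with `control_inputs` (a minimal sketch, assuming TF1-style graph
  mode; `log_op` is an illustrative op):

  ```python
  import tensorflow.compat.v1 as tf
  tf.disable_eager_execution()

  a = tf.constant(1.0) + 1.0
  b = tf.constant(2.0) * 3.0
  log_op = tf.print("computing a and b")
  # `a_out` and `b_out` are only available after `a`, `b` and `log_op` have
  # all run.
  a_out, b_out = tf.tuple([a, b], control_inputs=[log_op])

  with tf.Session() as sess:
    print(sess.run([a_out, b_out]))  # [2.0, 6.0]
  ```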
2991
2992  """
2993  if context.executing_eagerly():
2994    return tensors
2995  with ops.name_scope(name, "tuple", tensors) as name:
2996    tensors = [
2997        t if (isinstance(t, ops.Operation) or tensor_util.is_tensor(t) or
2998              t is None) else ops.convert_to_tensor(t) for t in tensors
2999    ]
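    # Collect the op producing each tensor; together with `control_inputs`,
    # these become the inputs of the gating `group` op created below.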
3000    gating_ops = [
3001        t if isinstance(t, ops.Operation) else t.op
3002        for t in tensors
3003        if t is not None
3004    ]
3005    if control_inputs:
3006      for c in control_inputs:
3007        if isinstance(c, ops.Tensor):
3008          c = c.op
3009        elif not isinstance(c, ops.Operation):
3010          raise TypeError("Control input must be Operation or Tensor: %s" % c)
3011        gating_ops.append(c)
3012    # Deduplicate the gating ops and sort them by op id so that the order of
3013    # ops in the serialized GraphDef (pbtxt) is deterministic.
3014    gating_ops = sorted(set(gating_ops), key=lambda op: op._id)
3015    if not gating_ops:
3016      raise ValueError("Must have at least one Tensor: %s" % tensors)
3017    gate = group(*gating_ops)
3018    tpl = []
3019    for t in tensors:
3020      if tensor_util.is_tensor(t):
3021        tpl.append(with_dependencies([gate], t))
3022      elif isinstance(t, ops.Operation):
3023        with ops.control_dependencies([gate]):
3024          tpl.append(group(t))
3025      else:
3026        tpl.append(None)
3027    return tpl
3028
3029
3030def _assert_at_most_n_true(predicates, n, msg):
3031  """Returns an Assert op that checks that at most n predicates are True.
3032
3033  Args:
3034    predicates: list of bool scalar tensors.
3035    n: maximum number of true predicates allowed.
3036    msg: Error message.
3037  """
3038  preds_c = array_ops.stack(predicates, name="preds_c")
3039  num_true_conditions = math_ops.reduce_sum(
3040      math_ops.cast(preds_c, dtypes.int32), name="num_true_conds")
3041  condition = math_ops.less_equal(num_true_conditions,
3042                                  constant_op.constant(n, name="n_true_conds"))
3043  preds_names = ", ".join(getattr(p, "name", "?") for p in predicates)
3044  error_msg = [
3045      "%s: more than %d conditions (%s) evaluated as True:" %
3046      (msg, n, preds_names), preds_c
3047  ]
3048  return Assert(condition, data=error_msg, summarize=len(predicates))
3049
3050
3051def _case_create_default_action(predicates, actions):
3052  """Creates default action for a list of actions and their predicates.
3053
3054  It selects an arbitrary one of the input actions as the default and adds
3055  assertions checking that the corresponding predicates have valid values.
3056
3057  Args:
3058    predicates: a list of bool scalar tensors
3059    actions: a list of callable objects which return tensors.
3060
3061  Returns:
3062    a callable
3063  """
3064  k = len(predicates) - 1  # could pick any
3065  predicate, action = predicates[k], actions[k]
3066  other_predicates, other_actions = predicates[:k], actions[:k]
3067
3068  def default_action():
3069    others_msg = ("Implementation error: "
3070                  "selected default action #%d was called, but some of other "
3071                  "predicates are True: " % k)
3072    default_msg = ("Input error: "
3073                   "None of conditions evaluated as True:",
3074                   array_ops.stack(predicates, name="preds_c"))
3075    with ops.control_dependencies([
3076        _assert_at_most_n_true(other_predicates, n=0, msg=others_msg),
3077        Assert(predicate, data=default_msg)
3078    ]):
3079      return action()
3080
3081  return default_action, other_predicates, other_actions
3082
3083
3084def _case_verify_and_canonicalize_args(pred_fn_pairs, exclusive, name,
3085                                       allow_python_preds):
3086  """Verifies input arguments for the case function.
3087
3088  Args:
3089    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor, and a
3090      callable which returns a list of tensors.
3091    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
3092    name: A name for the case operation.
3093    allow_python_preds: if true, pred_fn_pairs may contain Python bools in
3094      addition to boolean Tensors
3095
3096  Raises:
3097    TypeError: If `pred_fn_pairs` is not a list/dictionary.
3098    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
3099    TypeError: If `fns[i]` is not callable for any i, or `default` is not
3100               callable.
3101
3102  Returns:
3103    a tuple <list of scalar bool tensors, list of callables>.
3104  """
3105  if not isinstance(pred_fn_pairs, (list, _basetuple, dict)):
3106    raise TypeError("fns must be a list, tuple, or dict")
3107
3108  if isinstance(pred_fn_pairs, collections.OrderedDict):
3109    pred_fn_pairs = pred_fn_pairs.items()
3110  elif isinstance(pred_fn_pairs, dict):
3111    if context.executing_eagerly():
3112      # No name to sort on in eager mode. Use dictionary traversal order,
3113      # which is nondeterministic in versions of Python < 3.6
3114      if not exclusive:
3115        raise ValueError("Unordered dictionaries are not supported for the "
3116                         "`pred_fn_pairs` argument when `exclusive=False` and "
3117                         "eager mode is enabled.")
3118      pred_fn_pairs = list(pred_fn_pairs.items())
3119    else:
3120      pred_fn_pairs = sorted(
3121          pred_fn_pairs.items(), key=lambda item: item[0].name)
3122      if not exclusive:
3123        logging.warn(
3124            "%s: An unordered dictionary of predicate/fn pairs was "
3125            "provided, but exclusive=False. The order of conditional "
3126            "tests is deterministic but not guaranteed.", name)
3127  for pred_fn_pair in pred_fn_pairs:
3128    if not isinstance(pred_fn_pair, _basetuple) or len(pred_fn_pair) != 2:
3129      raise TypeError("Each entry in pred_fn_pairs must be a 2-tuple")
3130    pred, fn = pred_fn_pair
3131
3132    if isinstance(pred, ops.Tensor):
3133      if pred.dtype != dtypes.bool:
3134        raise TypeError("pred must be Tensor of type bool: %s" % pred.name)
3135    elif not allow_python_preds:
3136      raise TypeError("pred must be a Tensor, got: %s" % pred)
3137    elif not isinstance(pred, bool):
3138      raise TypeError("pred must be a Tensor or bool, got: %s" % pred)
3139
3140    if not callable(fn):
3141      raise TypeError("fn for pred %s must be callable." % getattr(pred, "name", pred))
3142
3143  predicates, actions = zip(*pred_fn_pairs)
3144  return predicates, actions
3145
3146
3147def _case_helper(cond_fn,
3148                 pred_fn_pairs,
3149                 default,
3150                 exclusive,
3151                 name,
3152                 allow_python_preds=False,
3153                 **cond_kwargs):
3154  """Implementation of case that allows for different cond functions.
3155
3156  Args:
3157    cond_fn: method that has signature and semantics of `cond` above.
3158    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor, and a
3159      callable which returns a list of tensors.
3160    default: Optional callable that returns a list of tensors.
3161    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
3162    name: A name for this operation (optional).
3163    allow_python_preds: if true, pred_fn_pairs may contain Python bools in
3164      addition to boolean Tensors
3165    **cond_kwargs: keyword arguments that will be passed to `cond_fn`.
3166
3167  Returns:
3168    The tensors returned by the first pair whose predicate evaluated to True, or
3169    those returned by `default` if none does.
3170
3171  Raises:
3172    TypeError: If `pred_fn_pairs` is not a list/dictionary.
3173    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
3174    TypeError: If `fns[i]` is not callable for any i, or `default` is not
3175               callable.
3176  """
3177  predicates, actions = _case_verify_and_canonicalize_args(
3178      pred_fn_pairs, exclusive, name, allow_python_preds)
3179  with ops.name_scope(name, "case", [predicates]):
3180    if default is None:
3181      default, predicates, actions = _case_create_default_action(
3182          predicates, actions)
3183    fn = default
3184    # To eval conditions in direct order we create nested conditions in reverse:
3185    #   cond_fn(c[0], true_fn=.., false_fn=cond_fn(c[1], ...))
3186    for predicate, action in reversed(list(zip(predicates, actions))):
3187      fn = functools.partial(
3188          cond_fn, predicate, true_fn=action, false_fn=fn, **cond_kwargs)
3189    if exclusive:
3190      with ops.control_dependencies([
3191          _assert_at_most_n_true(
3192              predicates, n=1, msg="Input error: exclusive=True")
3193      ]):
3194        return fn()
3195    else:
3196      return fn()
3197
3198
3199def _indexed_case_verify_and_canonicalize_args(branch_fns, default,
3200                                               branch_index):
3201  """Verifies input arguments for the case function.
3202
3203  Args:
3204    branch_fns: Dict or list of pairs of an `int` and a callable which
3205      returns a list of tensors.
3206    default: Optional callable that returns a list of tensors.
3207    branch_index: Optional int `Tensor`, which selects for the corresponding
3208      pred_fn_pair.
3209
3210  Raises:
3211    TypeError: If `branch_fns` is not a list/dictionary.
3212    TypeError: If `branch_fns` is a list but does not contain 2-tuples or
3213               callables.
3214    TypeError: If `fns[i]` is not callable for any i, or `default` is not
3215               callable.
3216
3217  Returns:
3218    branch_fns: validated list of callables for each branch (default last).
3219  """
3220  if not isinstance(branch_index, ops.Tensor):
3221    raise TypeError("branch_index must be a Tensor, got {}".format(
3222        type(branch_index)))
3223  if not branch_index.dtype.is_integer:
3224    raise TypeError("branch_index must be an integer Tensor, got {}".format(
3225        branch_index.dtype))
3226
3227  if not branch_fns:
3228    raise ValueError("Must provide at least one item in branch_fns")
3229  if not isinstance(branch_fns, (list, _basetuple, dict)):
3230    raise TypeError("branch_fns must be a list, tuple, or dict")
3231
3232  if isinstance(branch_fns, dict):
3233    branch_fns = branch_fns.items()
3234
3235  if all(callable(fn) for fn in branch_fns):
3236    branch_fns = list(enumerate(branch_fns))
3237
3238  for key_fn_pair in branch_fns:
3239    if not isinstance(key_fn_pair, _basetuple) or len(key_fn_pair) != 2:
3240      raise TypeError("Each entry in branch_fns must be a 2-tuple")
3241    key, branch_fn = key_fn_pair
3242
3243    if not isinstance(key, int):
3244      raise TypeError("key must be a Python `int`, got {}".format(type(key)))
3245
3246    if not callable(branch_fn):
3247      raise TypeError("fn for key {} must be callable.".format(key))
3248
3249  keys = [p[0] for p in branch_fns]
3250  if min(keys) < 0 or max(keys) >= len(keys) or len(set(keys)) != len(keys):
3251    raise ValueError(
3252        "branch indices (keys) must form a contiguous range [0, {}) but "
3253        "found {{{}}}".format(len(keys), ",".join(map(str, sorted(keys)))))
3254  actions = [p[1] for p in sorted(branch_fns)]
3255  if default is not None:
3256    actions.append(default)
3257  return actions
3258
3259
3260def _indexed_case_helper(branch_fns, default, branch_index, name):
3261  """Implementation of case that emits the n-way indexed Case op.
3262
3263  Args:
3264    branch_fns: Dict or list of pairs of a boolean scalar tensor, and a
3265      callable which returns a list of tensors.
3266    default: Optional callable that returns a list of tensors.
3267    branch_index: Optional int `Tensor`, which selects for the corresponding
3268      pred_fn_pair.
3269    name: A name for this operation (optional).
3270
3271  Returns:
3272    The tensors returned by the pair whose key matched branch_index, or
3273    those returned by `default` if none does.
3274
3275  Raises:
3276    TypeError: If `branch_fns` is not a list/dictionary.
3277    TypeError: If `branch_fns` is a list but does not contain 2-tuples or
3278               callables.
3279    TypeError: If `fns[i]` is not callable for any i, or `default` is not
3280               callable.
3281  """
3282  branch_fns = _indexed_case_verify_and_canonicalize_args(
3283      branch_fns, default, branch_index)
3284  with ops.name_scope(name, "case", [branch_index]):
3285    if context.executing_eagerly() and not hasattr(branch_index, "graph"):
3286      branch_index = array_ops.where(
3287          math_ops.less(branch_index, 0)
3288          | math_ops.greater_equal(branch_index, len(branch_fns)),
3289          len(branch_fns) - 1, branch_index)
3290      return branch_fns[int(branch_index)]()
3291    return cond_v2.indexed_case(branch_index, branch_fns)
3292
3293
3294@tf_export("case", v1=[])
3295def case_v2(pred_fn_pairs,
3296            default=None,
3297            exclusive=False,
3298            strict=False,
3299            name="case"):
3300  """Create a case operation.
3301
3302  See also `tf.switch_case`.
3303
3304  The `pred_fn_pairs` parameter is a list of N pairs.
3305  Each pair contains a boolean scalar tensor and a python callable that
3306  creates the tensors to be returned if the boolean evaluates to True.
3307  `default` is a callable generating a list of tensors. All the callables
3308  in `pred_fn_pairs` as well as `default` (if provided) should return the same
3309  number and types of tensors.
3310
3311  If `exclusive==True`, all predicates are evaluated, and an exception is
3312  thrown if more than one of the predicates evaluates to `True`.
3313  If `exclusive==False`, execution stops at the first predicate which
3314  evaluates to True, and the tensors generated by the corresponding function
3315  are returned immediately. If none of the predicates evaluate to True, this
3316  operation returns the tensors generated by `default`.
3317
3318  `tf.case` supports nested structures as implemented in
3319  `tf.nest`. All of the callables must return the same
3320  (possibly nested) value structure of lists, tuples, and/or named tuples.
3321  Singleton lists and tuples form the only exceptions to this: when returned by
3322  a callable, they are implicitly unpacked to single values. This
3323  behavior is disabled by passing `strict=True`.
3324
3325  @compatibility(v2)
3326  `pred_fn_pairs` could be a dictionary in v1. However, tf.Tensor and
3327  tf.Variable are no longer hashable in v2, so cannot be used as a key for a
3328  dictionary.  Please use a list or a tuple instead.
3329  @end_compatibility
3330
3331
3332  **Example 1:**
3333
3334  Pseudocode:
3335
3336  ```
3337  if (x < y) return 17;
3338  else return 23;
3339  ```
3340
3341  Expressions:
3342
3343  ```python
3344  f1 = lambda: tf.constant(17)
3345  f2 = lambda: tf.constant(23)
3346  r = tf.case([(tf.less(x, y), f1)], default=f2)
3347  ```
3348
3349  **Example 2:**
3350
3351  Pseudocode:
3352
3353  ```
3354  if (x < y && x > z) raise OpError("Only one predicate may evaluate to True");
3355  if (x < y) return 17;
3356  else if (x > z) return 23;
3357  else return -1;
3358  ```
3359
3360  Expressions:
3361
3362  ```python
3363  def f1(): return tf.constant(17)
3364  def f2(): return tf.constant(23)
3365  def f3(): return tf.constant(-1)
3366  r = tf.case([(tf.less(x, y), f1), (tf.greater(x, z), f2)],
3367           default=f3, exclusive=True)
3368  ```
3369
3370  Args:
3371    pred_fn_pairs: List of pairs of a boolean scalar tensor and a callable which
3372      returns a list of tensors.
3373    default: Optional callable that returns a list of tensors.
3374    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
3375    strict: A boolean that enables/disables 'strict' mode; see above.
3376    name: A name for this operation (optional).
3377
3378  Returns:
3379    The tensors returned by the first pair whose predicate evaluated to True, or
3380    those returned by `default` if none does.
3381
3382  Raises:
3383    TypeError: If `pred_fn_pairs` is not a list/tuple.
3384    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
3385    TypeError: If `fns[i]` is not callable for any i, or `default` is not
3386               callable.
3387  """
3388  return _case_helper(
3389      cond,
3390      pred_fn_pairs,
3391      default,
3392      exclusive,
3393      name,
3394      allow_python_preds=False,
3395      strict=strict)
3396
3397
3398@tf_export(v1=["case"])
3399def case(pred_fn_pairs,
3400         default=None,
3401         exclusive=False,
3402         strict=False,
3403         name="case"):
3404  """Create a case operation.
3405
3406  See also `tf.switch_case`.
3407
3408  The `pred_fn_pairs` parameter is a dict or list of N pairs.
3409  Each pair contains a boolean scalar tensor and a python callable that
3410  creates the tensors to be returned if the boolean evaluates to True.
3411  `default` is a callable generating a list of tensors. All the callables
3412  in `pred_fn_pairs` as well as `default` (if provided) should return the same
3413  number and types of tensors.
3414
3415  If `exclusive==True`, all predicates are evaluated, and an exception is
3416  thrown if more than one of the predicates evaluates to `True`.
3417  If `exclusive==False`, execution stops at the first predicate which
3418  evaluates to True, and the tensors generated by the corresponding function
3419  are returned immediately. If none of the predicates evaluate to True, this
3420  operation returns the tensors generated by `default`.
3421
3422  `tf.case` supports nested structures as implemented in
3423  `tf.contrib.framework.nest`. All of the callables must return the same
3424  (possibly nested) value structure of lists, tuples, and/or named tuples.
3425  Singleton lists and tuples form the only exceptions to this: when returned by
3426  a callable, they are implicitly unpacked to single values. This
3427  behavior is disabled by passing `strict=True`.
3428
3429  If an unordered dictionary is used for `pred_fn_pairs`, the order of the
3430  conditional tests is not guaranteed. However, the order is guaranteed to be
3431  deterministic, so that variables created in conditional branches are created
3432  in fixed order across runs.
3433
3434  @compatibility(eager)
3435  Unordered dictionaries are not supported in eager mode when `exclusive=False`.
3436  Use a list of tuples instead.
3437  @end_compatibility
3438
3439
3440  **Example 1:**
3441
3442  Pseudocode:
3443
3444  ```
3445  if (x < y) return 17;
3446  else return 23;
3447  ```
3448
3449  Expressions:
3450
3451  ```python
3452  f1 = lambda: tf.constant(17)
3453  f2 = lambda: tf.constant(23)
3454  r = tf.case([(tf.less(x, y), f1)], default=f2)
3455  ```
3456
3457  **Example 2:**
3458
3459  Pseudocode:
3460
3461  ```
3462  if (x < y && x > z) raise OpError("Only one predicate may evaluate to True");
3463  if (x < y) return 17;
3464  else if (x > z) return 23;
3465  else return -1;
3466  ```
3467
3468  Expressions:
3469
3470  ```python
3471  def f1(): return tf.constant(17)
3472  def f2(): return tf.constant(23)
3473  def f3(): return tf.constant(-1)
3474  r = tf.case({tf.less(x, y): f1, tf.greater(x, z): f2},
3475           default=f3, exclusive=True)
3476  ```
3477
3478  Args:
3479    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a
3480      callable which returns a list of tensors.
3481    default: Optional callable that returns a list of tensors.
3482    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
3483    strict: A boolean that enables/disables 'strict' mode; see above.
3484    name: A name for this operation (optional).
3485
3486  Returns:
3487    The tensors returned by the first pair whose predicate evaluated to True, or
3488    those returned by `default` if none does.
3489
3490  Raises:
3491    TypeError: If `pred_fn_pairs` is not a list/dictionary.
3492    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
3493    TypeError: If `fns[i]` is not callable for any i, or `default` is not
3494               callable.
3495  """
3496  return _case_helper(
3497      cond,
3498      pred_fn_pairs,
3499      default,
3500      exclusive,
3501      name,
3502      allow_python_preds=False,
3503      strict=strict)
3504
3505
3506@tf_export("switch_case")
3507def switch_case(branch_index,
3508                branch_fns,
3509                default=None,
3510                name="switch_case"):
3511  """Create a switch/case operation, i.e. an integer-indexed conditional.
3512
3513  See also `tf.case`.
3514
3515  This op can be substantially more efficient than `tf.case` when exactly one
3516  branch will be selected. `tf.switch_case` is more like a C++ switch/case
3517  statement than `tf.case`, which is more like an if/elif/elif/else chain.
3518
3519  The `branch_fns` parameter is either a dict from `int` to callables, or list
3520  of (`int`, callable) pairs, or simply a list of callables (in which case the
3521  index is implicitly the key). The `branch_index` `Tensor` is used to select an
3522  element in `branch_fns` with matching `int` key, falling back to `default`
3523  if none match, or `max(keys)` if no `default` is provided. The keys must form
3524  a contiguous set from `0` to `len(branch_fns) - 1`.
3525
3526  `tf.switch_case` supports nested structures as implemented in `tf.nest`. All
3527  callables must return the same (possibly nested) value structure of lists,
3528  tuples, and/or named tuples.
3529
3530  **Example:**
3531
3532  Pseudocode:
3533
3534  ```c++
3535  switch (branch_index) {  // c-style switch
3536    case 0: return 17;
3537    case 1: return 31;
3538    default: return -1;
3539  }
3540  ```
3541  or
3542  ```python
3543  branches = {0: lambda: 17, 1: lambda: 31}
3544  branches.get(branch_index, lambda: -1)()
3545  ```
3546
3547  Expressions:
3548
3549  ```python
3550  def f1(): return tf.constant(17)
3551  def f2(): return tf.constant(31)
3552  def f3(): return tf.constant(-1)
3553  r = tf.switch_case(branch_index, branch_fns={0: f1, 1: f2}, default=f3)
3554  # Equivalent: tf.switch_case(branch_index, branch_fns={0: f1, 1: f2, 2: f3})
3555  ```
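
  The list-of-callables form uses each callable's position as its key (a
  sketch reusing `f1`, `f2` and `f3` from the example above):

  ```python
  r = tf.switch_case(branch_index, branch_fns=[f1, f2, f3])
  ```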
3556
3557  Args:
3558    branch_index: An int Tensor specifying which of `branch_fns` should be
3559      executed.
3560    branch_fns: A `dict` mapping `int`s to callables, or a `list` of
3561      (`int`, callable) pairs, or simply a list of callables (in which case the
3562      index serves as the key). Each callable must return a matching structure
3563      of tensors.
3564    default: Optional callable that returns a structure of tensors.
3565    name: A name for this operation (optional).
3566
3567  Returns:
3568    The tensors returned by the callable identified by `branch_index`, or those
3569    returned by `default` if no key matches and `default` was provided, or those
3570    returned by the max-keyed `branch_fn` if no `default` is provided.
3571
3572  Raises:
3573    TypeError: If `branch_fns` is not a list/dictionary.
3574    TypeError: If `branch_fns` is a list but does not contain 2-tuples or
3575               callables.
3576    TypeError: If `fns[i]` is not callable for any i, or `default` is not
3577               callable.
3578  """
3579  return _indexed_case_helper(branch_fns, default, branch_index, name)
3580
3581
3582class XLAControlFlowContext(ControlFlowContext):
3583  """Base class for XLA and TPU control flow contexts."""
3584
3585  def __init__(self):
3586    super(XLAControlFlowContext, self).__init__()
3587    self._name = "XLAControlFlowContext"
3588
3589  def to_control_flow_context_def(self, context_def, export_scope=None):
3590    # pylint: disable=useless-super-delegation
3591    # NOTE(slebedev): the method is required by `ControlFlowContext`.
3592    super(XLAControlFlowContext,
3593          self).to_control_flow_context_def(context_def, export_scope)
3594
3595  def IsXLAContext(self):
3596    return True
3597
3598  def AddOp(self, _):
3599    pass
3600
3601  def AddValue(self, x):
3602    return x
3603
3604
3605def from_control_flow_context_def(context_def, import_scope=None):
3606  """Deserializes `context_def` into the appropriate ControlFlowContext.
3607
3608  Args:
3609    context_def: ControlFlowContextDef proto
3610    import_scope: Optional `string`. Name scope to add.
3611
3612  Returns:
3613    A ControlFlowContext subclass
3614  """
3615  if context_def.HasField("cond_ctxt"):
3616    return CondContext.from_proto(
3617        context_def.cond_ctxt, import_scope=import_scope)
3618  if context_def.HasField("while_ctxt"):
3619    return WhileContext.from_proto(
3620        context_def.while_ctxt, import_scope=import_scope)
3621  raise NotImplementedError("Unknown ControlFlowContextDef field: %s" %
3622                            context_def.WhichOneof("ctxt"))
3623
3624
3625ops.register_proto_function(
3626    ops.GraphKeys.COND_CONTEXT,
3627    proto_type=control_flow_pb2.CondContextDef,
3628    to_proto=CondContext.to_proto,
3629    from_proto=CondContext.from_proto)
3630
3631ops.register_proto_function(
3632    ops.GraphKeys.WHILE_CONTEXT,
3633    proto_type=control_flow_pb2.WhileContextDef,
3634    to_proto=WhileContext.to_proto,
3635    from_proto=WhileContext.from_proto)
3636