# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Control flow statements: loops, conditionals, etc.
16
17Note: most of these operators accept pairs of get_state/set_state functions, to
18capture mutations that the corresponding code blocks might make. These
19mutations only need to be captured when staging the control flow, and they just
20work when reverting to Python behavior.
21
22__Examples__
23
24```
25while cond:
26  self.x += i
27```
28
29When the functionalized version is executed as a Python loop, it just works:
30
31```
32def loop_body():
33  self.x += i     # works as expected for Python loops
34```
35
36But it won't work for TF loops:
37
38```
39def loop_body():
40  self.x += i     # self.x has the wrong value!
41```
42
43get_state/set_state allow piping the mutations through the loop variables as
44well, in effect changing the loop body:
45
46```
47def loop_body(self_x):
48  self.x = self_x  # self.x now has the proper value
49  self.x += i      # the original block
50  self_x = self.x  # write self.x back into the loop vars
51  return self_x
52
53self_x = tf.while_loop(...)
self.x = self_x    # the result is properly captured
```
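
A sketch of the get_state/set_state pair that pipes `self.x` through the loop
variables (illustrative only; the actual generated functions differ):

```
def get_state():
  return (self.x,)

def set_state(loop_vars):
  self_x, = loop_vars
  self.x = self_x
```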
56"""
57
58import functools
59import sys
60import traceback
61
62import numpy as np
63
64from tensorflow.python.autograph.operators import py_builtins
65from tensorflow.python.autograph.operators import variables
66from tensorflow.python.autograph.utils import ag_logging
67from tensorflow.python.autograph.utils import misc
68from tensorflow.python.autograph.utils import tensors
69from tensorflow.python.data.experimental.ops import take_while_ops
70from tensorflow.python.data.ops import dataset_ops
71from tensorflow.python.data.ops import iterator_ops
72from tensorflow.python.framework import constant_op
73from tensorflow.python.framework import dtypes
74from tensorflow.python.framework import errors_impl
75from tensorflow.python.framework import func_graph
76from tensorflow.python.framework import ops
77from tensorflow.python.framework import tensor_shape
78from tensorflow.python.framework import tensor_util
79from tensorflow.python.ops import array_ops
80from tensorflow.python.ops import control_flow_ops
81from tensorflow.python.ops import control_flow_util
82from tensorflow.python.ops import math_ops
83from tensorflow.python.ops import tensor_array_ops
84from tensorflow.python.ops.ragged import ragged_tensor
85from tensorflow.python.types import distribute
86from tensorflow.python.util import nest
87from tensorflow.python.util import variable_utils
88
89
90PYTHON_MAX_ITERATIONS = 100000000  # Fails in about one minute for empty loops.
91WARN_INEFFICIENT_UNROLL = True
92INEFFICIENT_UNROLL_MIN_ITERATIONS = 50000
93INEFFICIENT_UNROLL_MIN_OPS = 1
94
95
96# TODO(mdan): Use the custom operator pattern instead of type dispatch.
97# An example of this pattern is found in the implementation of distributed
98# datasets. Before it can be used though, we need to standardize the interface.
99
100
101def _is_none_or_undef(value):
102  """Tests whether a value is None or undefined.
103
104  AutoGraph represents undefined symbols using special objects of type Undefined
105  or UndefinedReturnValue.
106
107  Args:
108    value: value to test
109
110  Returns:
111    Boolean
112  """
113  return ((value is None)
114          or isinstance(value, variables.UndefinedReturnValue)
115          or isinstance(value, variables.Undefined))
116
117
118def _verify_tf_condition(cond, tag):
119  """Ensures that the condition can be used in a TF control flow."""
120  extra_hint = 'to check for None, use `is not None`'
121  cond = ops.convert_to_tensor_v2(cond)
122
123  if cond.dtype != dtypes.bool:
124    raise ValueError(
125        'condition of {} expected to be `tf.bool` scalar, got {}'
126        '; to use as boolean Tensor, use `tf.cast`'
127        '; {}'.format(tag, cond, extra_hint))
128
129  if cond.shape is None or cond.shape.ndims is None:
    # TODO(mdan): Consider an explicit size check, if not too slow.
    cond = array_ops.reshape(cond, ())

  elif cond.shape.ndims > 0:
    known_dims = [d for d in cond.shape.as_list() if d is not None]
    if np.prod(known_dims) > 1:
      raise ValueError(
          'condition of {} expected to be `tf.bool` scalar, got {}'
          '; {}'.format(tag, cond, extra_hint))
    else:
      cond = array_ops.reshape(cond, ())

  return cond


def _verify_loop_init_vars(init_vars,
                           symbol_names,
                           first_iter_vars=None,
                           extra_message=None):
  """Ensures that all values in the state are valid to use in a TF loop.

  The init_vars may contain placeholder values derived from first_iter_vars.

  Args:
    init_vars: initial loop variables (as taken before entering the loop)
    symbol_names: corresponding names of the initial loop variables
    first_iter_vars: loop variables after one iteration of the loop
    extra_message: an extra string to append to the error message, in case of
      "undefined variable" errors (see variables.Undefined)
  """
  if not symbol_names:
    return
  if first_iter_vars is None:
    first_iter_vars = (None,) * len(symbol_names)

  assert len(symbol_names) == len(init_vars)
  assert len(symbol_names) == len(first_iter_vars)
  for name, val, fi_val in zip(symbol_names, init_vars, first_iter_vars):
    if isinstance(val, variables.UndefinedReturnValue):
      if fi_val:
        raise ValueError(
            'the return value from a TensorFlow loop may only be a {}; got {}'
            .format(LEGAL_LOOP_TYPES, type(fi_val)))
      else:
        # TODO(mdan): This can be handled by removing the return value.
        raise NotImplementedError(
            'a return statement cannot be placed inside this TensorFlow loop;'
            ' this may happen if a return statement depends on a'
            ' static Python condition such as a hyperparameter')

    error_msg = None
    if val is None:
      error_msg = "'{}' may not be None before the loop".format(name)
    elif isinstance(val, variables.Undefined):
      error_msg = "'{}' must be defined before the loop".format(name)
      if extra_message:
        error_msg += '\n' + extra_message

    if error_msg is not None:
      raise ValueError(error_msg)


def _is_subshape(left, right):
  """Returns True if left shape is at least as specific as right shape."""
  # TODO(mdan): This code should be in TensorShape.
  # Note: this is not the same as TensorShape.is_compatible_with, which is
  # symmetric.
  # This code also duplicates _ShapeLessThanOrEqual from control_flow_ops.py.
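  # Illustrative examples (hypothetical shapes): a fully-known shape is a
  # subshape of a partially-known one, but not vice versa:
  #   _is_subshape(TensorShape([3, 4]), TensorShape([None, 4])) -> True
  #   _is_subshape(TensorShape([None, 4]), TensorShape([3, 4])) -> False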
  if right.dims is None:
    return True
  if left.ndims != right.ndims:
    return False
  for ldim, rdim in zip(left.dims, right.dims):
    if rdim.value is not None and ldim.value != rdim.value:
      return False
  return True


# TODO(mdan): Remove these verifications once TF ops can properly report names.
def _verify_single_loop_var(
    name, check_shape, init, entry, exit_, shape_invariant):
  """Verifies whether the initial, entry and exit values are consistent."""
  assert entry is not None, "no TF op should set '{}' to None?".format(name)
  if exit_ is None:
    raise ValueError("'{}' is None at the end of the iteration.".format(name))

  if isinstance(init, (bool, int, float, str, np.ndarray)):
    init = ops.convert_to_tensor_v2(init)
  if isinstance(entry, (bool, int, float, str, np.ndarray)):
    entry = ops.convert_to_tensor_v2(entry)
  if isinstance(exit_, (bool, int, float, str, np.ndarray)):
    exit_ = ops.convert_to_tensor_v2(exit_)

  if (not tensor_util.is_tf_type(entry) or
      not tensor_util.is_tf_type(exit_)):
    return

  # TODO(mdan): Properly account for CompositeTensors.
  if (not hasattr(entry, 'dtype') or
      not hasattr(exit_, 'dtype')):
    return
  if (not hasattr(entry, 'shape') or
      not hasattr(exit_, 'shape')):
    return

  if entry.dtype != exit_.dtype:
    raise TypeError(
        "'{}' has dtype {} before the loop, but dtype {} after one"
        ' iteration'.format(
            name,
            entry.dtype.name,
            exit_.dtype.name,
        ))
  if check_shape:
    exit_shape = exit_.shape
    if shape_invariant is None:
      entry_shape = entry.shape
      if not _is_subshape(exit_shape, entry_shape):
        raise ValueError(
            "'{}' has shape {} before the loop, but shape {} after one"
            ' iteration. Use tf.autograph.experimental.set_loop_options to set'
            ' shape invariants.'.format(name, entry_shape, exit_shape))
    else:
      init_shape = init.shape
      if not _is_subshape(init_shape, shape_invariant):
        raise ValueError(
            "'{}' has shape {} before the loop, which does not conform with"
            ' the shape invariant {}.'.format(name, init_shape,
                                              shape_invariant))
      if not _is_subshape(exit_shape, shape_invariant):
        raise ValueError(
            "'{}' has shape {} after one iteration, which does not conform with"
            ' the shape invariant {}.'.format(
                name, exit_shape, shape_invariant))


def _verify_tf_loop_vars(init_vars,
                         iter_entry_vars,
                         iter_exit_vars,
                         symbol_names,
                         opts,
                         check_shapes=True):
  """Verifies loop variables for consistency."""
  if check_shapes and 'shape_invariants' in opts:
    shape_invariants = opts['shape_invariants']
  else:
    shape_invariants = nest.map_structure(lambda _: None, iter_entry_vars)

  assert len(symbol_names) == len(shape_invariants)
  assert len(symbol_names) == len(init_vars)
  assert len(symbol_names) == len(iter_entry_vars)
  assert len(symbol_names) == len(iter_exit_vars)

  for i in range(len(symbol_names)):
    name = symbol_names[i]
    init = init_vars[i]
    entry = iter_entry_vars[i]
    exit_ = iter_exit_vars[i]
    invariant = shape_invariants[i]

    try:
      nest.assert_same_structure(init, entry, expand_composites=True)
    except (ValueError, TypeError):
      # `Variable`s in `init` may be implicitly converted to `Tensor`s. Convert
      # `ResourceVariable`s to Tensors so tf.nest.assert_same_structure
      # won't break due to type spec mismatches between `ResourceVariable`s and
      # `Tensor`s.
      try:
        init_tensors = variable_utils.convert_variables_to_tensors(init)
        nest.assert_same_structure(init_tensors, entry, expand_composites=True)
      except (ValueError, TypeError) as e:
        raise TypeError("'{}' does not have the same nested structure after one"
                        ' iteration.\n\n{}'.format(name, e)) from e

    try:
      nest.assert_same_structure(entry, exit_, expand_composites=True)
    except (ValueError, TypeError) as e:
      raise TypeError("'{}' does not have the same nested structure after one"
                      ' iteration.\n\n{}'.format(name, e)) from e
    if invariant is not None:
      try:
        nest.assert_same_structure(init, invariant, expand_composites=False)
      except (ValueError, TypeError) as e:
        raise TypeError("'{}' does not have the same nested structure as its"
                        ' corresponding shape invariant.\n\n{}'.format(
                            name, e)) from e

    nest.map_structure(
        functools.partial(_verify_single_loop_var, name, check_shapes), init,
        entry, exit_, invariant)


def verify_single_cond_var(name, body_var, orelse_var):
  """Verifies whether body_var and orelse_var are consistent."""
  if body_var is None:
    raise ValueError("'{}' is None at the end of the main branch.".format(name))
  if orelse_var is None:
    raise ValueError(
        "'{}' is None at the end of the else branch.".format(name))

  if isinstance(body_var, (bool, int, float, str, np.ndarray)):
    body_var = ops.convert_to_tensor_v2(body_var)

  if isinstance(orelse_var, (bool, int, float, str, np.ndarray)):
    orelse_var = ops.convert_to_tensor_v2(orelse_var)

  if (not tensor_util.is_tf_type(body_var) or
      not tensor_util.is_tf_type(orelse_var)):
    return

  # TODO(mdan): Properly account for CompositeTensors.
  if (not hasattr(body_var, 'dtype') or
      not hasattr(orelse_var, 'dtype')):
    return

  if body_var.dtype != orelse_var.dtype:
    raise TypeError(
        "'{}' has dtype {} in the main branch, but dtype {} in the else"
        ' branch'.format(name, body_var.dtype.name,
                         orelse_var.dtype.name))


def _verify_tf_cond_branch_vars(vars_, symbol_names, branch_name):
  """Verifies variables output by a conditional branch for consistency."""
  for name, var_ in zip(symbol_names, vars_):
    if isinstance(var_, variables.Undefined):
      raise ValueError(
          "'{}' must also be initialized in the {} branch".format(
              name, branch_name))
    if isinstance(var_, variables.UndefinedReturnValue):
      raise ValueError(
          'the {} branch must also have a return statement.'.format(
              branch_name))


def _verify_tf_cond_vars(body_vars, orelse_vars, symbol_names):
  """Verifies variables manipulated by a conditional for consistency."""
  named_vars = zip(symbol_names, body_vars, orelse_vars)

  for name, body_var, orelse_var in named_vars:
    try:
      nest.assert_same_structure(body_var, orelse_var, expand_composites=True)
    except (ValueError, TypeError):
      # One branch of cond could be a `Tensor`, while the other branch could be
      # a `ResourceVariable`. Convert `ResourceVariable`s to `Tensor`s so
      # assert_same_structure won't fail.
      try:
        body_var_tensors = variable_utils.convert_variables_to_tensors(body_var)
        orelse_var_tensors = variable_utils.convert_variables_to_tensors(
            orelse_var)
        nest.assert_same_structure(body_var_tensors, orelse_var_tensors,
                                   expand_composites=True)
      except (ValueError, TypeError) as e:
        raise TypeError(
            "'{}' must have the same nested structure in the main and else"
            ' branches:\n\n{}'.format(name, str(e))) from e
    nest.map_structure(
        functools.partial(verify_single_cond_var, name), body_var, orelse_var)


def for_stmt(iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Functional form of a for statement.

  The loop operates on a state, which includes all symbols that are
  variant across loop iterations, excluding the variables local to the loop.

  For example, given the loop below that calculates the geometric and
  arithmetic means of some numbers:

  ```
    geo_mean = 1
    arith_mean = 0
    for i in range(n):
      a = numbers[i]
      geo_mean *= a
      arith_mean += a
  ```

  The state is represented by the variables geo_mean and arith_mean. The
  `extra_test`, `body`, `get_state` and `set_state` functions must bind to the
  original `geo_mean` and `arith_mean` symbols, using `nonlocal`.
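
  As an illustrative sketch (not the code AutoGraph actually generates), the
  loop above could be passed to `for_stmt` roughly as follows:

  ```
    def body(i):
      nonlocal geo_mean, arith_mean
      a = numbers[i]
      geo_mean *= a
      arith_mean += a

    def get_state():
      return (geo_mean, arith_mean)

    def set_state(loop_vars):
      nonlocal geo_mean, arith_mean
      geo_mean, arith_mean = loop_vars

    for_stmt(range(n), None, body, get_state, set_state,
             ('geo_mean', 'arith_mean'), {})
  ```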

  The inputs and outputs of the callables representing the loop blocks are not
  explicit - instead, these functions must use nonlocal/global for side effects.
  The inputs and outputs are instead controlled by the set_state/get_state
  functions.

  Args:
    iter_: The entity being iterated over.
    extra_test: Callable with boolean return type. An additional loop condition.
    body: Callable representing the actual loop body.
    get_state: Additional callable which can capture additional state (such as
      the values of composite symbols). This is only useful when staging the
      loop.
    set_state: Additional callable which saves values captured by get_state back
      into the Python environment. This is only useful when staging the loop.
    symbol_names: Tuple containing names of the loop variables returned by
      get_state.
    opts: Optional dict of extra loop parameters.
  """
  if tensor_util.is_tf_type(iter_):
    if tensors.is_range_tensor(iter_):
      _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state,
                         symbol_names, opts)
    elif isinstance(iter_, ragged_tensor.RaggedTensor):
      _tf_ragged_for_stmt(iter_, extra_test, body, get_state, set_state,
                          symbol_names, opts)
    else:
      _known_len_tf_for_stmt(
          iter_, extra_test, body, get_state, set_state, symbol_names, opts)

  elif isinstance(iter_, dataset_ops.DatasetV2):
    _tf_dataset_for_stmt(
        iter_, extra_test, body, get_state, set_state, symbol_names, opts)

  elif isinstance(iter_, iterator_ops.OwnedIterator):
    _tf_iterator_for_stmt(
        iter_, extra_test, body, get_state, set_state, symbol_names, opts)

  elif isinstance(iter_, ragged_tensor.RaggedTensor):
    _tf_ragged_for_stmt(
        iter_, extra_test, body, get_state, set_state, symbol_names, opts)

  elif isinstance(iter_, distribute.Iterator):
    _tf_iterator_for_stmt(
        iter_, extra_test, body, get_state, set_state, symbol_names, opts)

  elif isinstance(iter_, distribute.Iterable):
    # TODO(b/162250181): Use _tf_iterator_for_stmt(iter(iter_)...
    _tf_distributed_iterable_for_stmt(
        iter_, extra_test, body, get_state, set_state, symbol_names, opts)

  else:
    _py_for_stmt(iter_, extra_test, body, None, None)


def _py_for_stmt(iter_, extra_test, body, get_state, set_state):
  """Overload of for_stmt that executes a Python for loop."""
  del get_state, set_state

  if __debug__:
    checker = _PythonLoopChecker()
    before_iteration = checker.before_iteration
    after_iteration = checker.after_iteration
    before_iteration()

    original_body = body
    def protected_body(protected_iter):
      original_body(protected_iter)
      after_iteration()
      before_iteration()
    body = protected_body

  if extra_test is not None:
    def guarded_extra_test():
      extra_test_result = extra_test()
      try:
        # Note: Using try/except and not tensor_util.is_tf_type to avoid
        # performance degradation.
        return bool(extra_test_result)
      except errors_impl.OperatorNotAllowedInGraphError as e:
        ag_logging.log(
            1,
            'Caught error while evaluating loop stop condition',
            exc_info=True)
        # TODO(mdan): We can pass the location of extra_test and show it here.
        raise NotImplementedError(
            'break and return statements which depend on a TF condition are not'
            ' supported in Python for loops. Did you intend to make it a TF'
            ' loop?\nSee '
            'https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
            'python/autograph/g3doc/reference/limitations.md'
            '#consistency-of-control-flow-types for more info.') from e

    if guarded_extra_test():
      for target in iter_:
        body(target)
        if not guarded_extra_test():
          break

  else:
    for target in iter_:
      body(target)


def _add_max_iterations_hint(opts, n):
  # TODO(b/159186914): Remove the safeguard, and always set maximum_iterations.
  if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
    opts['maximum_iterations'] = n


def _known_len_tf_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over TF entities that admit a length."""
  n = py_builtins.len_(iter_)

  # TODO(b/117628877): Revisit performance once XLA has the necessary support.
  # Note: using a TensorArray creates an extra copy, but can calculate
  # gradients more efficiently than StridedSlice.
  ta = tensor_array_ops.TensorArray(iter_.dtype, size=n)
  iter_ = ta.unstack(iter_)

  iterate_index = 0

  def aug_get_state():
    return (iterate_index,) + get_state()

  def aug_set_state(aug_loop_vars):
    nonlocal iterate_index
    # TODO(b/171479293): Drop the lint override.
    iterate_index, *loop_vars = aug_loop_vars  # pylint:disable=unused-variable
    # The iteration index is not "output" by the for loop. If the iterate
    # is used outside the loop, it will appear in the loop vars separately.
    set_state(loop_vars)

  def aug_body():
    nonlocal iterate_index
    body(iter_.read(iterate_index))
    iterate_index += 1

  def aug_test():
    main_test = iterate_index < n
    if extra_test is not None:
      return control_flow_ops.cond(main_test, extra_test, lambda: False)
    return main_test

  _add_max_iterations_hint(opts, n)

  _tf_while_stmt(
      aug_test,
      aug_body,
      aug_get_state,
      aug_set_state,
      ('<internal iterate>',) + symbol_names,
      opts,
  )


def _tf_ragged_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over TF ragged tensors."""
  init_vars = get_state()
  _verify_loop_init_vars(init_vars, symbol_names)

  # TODO(mdan): Move this into len()? Requires eager support.
  if iter_.shape and iter_.shape[0] is not None:
    n = iter_.shape[0]
  else:
    n = iter_.row_lengths()[0]

  iterate_index = 0

  def aug_get_state():
    return (iterate_index,) + get_state()

  def aug_set_state(aug_loop_vars):
    nonlocal iterate_index
    # TODO(b/171479293): Drop the lint override.
    iterate_index, *loop_vars = aug_loop_vars  # pylint:disable=unused-variable
    # The iteration index is not "output" by the for loop. If the iterate
    # is used outside the loop, it will appear in the loop vars separately.
    set_state(loop_vars)

  def aug_body():
    nonlocal iterate_index
    body(iter_[iterate_index])
    iterate_index += 1

  def aug_test():
    main_test = iterate_index < n
    if extra_test is not None:
      return control_flow_ops.cond(main_test, extra_test, lambda: False)
    return main_test

  _add_max_iterations_hint(opts, n)

  _tf_while_stmt(
      aug_test,
      aug_body,
      aug_get_state,
      aug_set_state,
      ('<internal iterate>',) + symbol_names,
      opts)


def _tf_range_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over a TF range (and elides it)."""
  start, limit, delta = iter_.op.inputs

  iterate = start

  def _value_or(name, var, default):
    if (name == opts['iterate_names'] and isinstance(var, variables.Undefined)):
      return default
    return var

  def aug_get_state():
    state_vars = get_state()
    state_vars = tuple(
        _value_or(name, var, iterate)
        for name, var in zip(symbol_names, state_vars))
    return (iterate,) + state_vars

  def aug_set_state(aug_loop_vars):
    nonlocal iterate
    # TODO(b/171479293): Drop the lint override.
    iterate, *loop_vars = aug_loop_vars  # pylint:disable=unused-variable
    # The iteration index is not "output" by the for loop. If the iterate
    # is used outside the loop, it will appear in the loop vars separately.
    set_state(loop_vars)

  def aug_body():
    nonlocal iterate
    body(iterate)
    iterate += delta

  def aug_test():
    # TODO(b/159713842): Remove once constant folding works.
    const_delta = tensor_util.constant_value(delta)
    if const_delta is not None:
      if const_delta >= 0:
        main_test = iterate < limit
      else:
        main_test = iterate > limit
    else:
      main_test = math_ops.logical_or(
          math_ops.logical_and(delta >= 0, iterate < limit),
          math_ops.logical_and(delta < 0, iterate > limit))

    if extra_test is not None:
      main_test = control_flow_ops.cond(main_test, extra_test, lambda: False)
    return main_test

  _add_max_iterations_hint(
      opts,
      math_ops.cast(misc.get_range_len(start, limit, delta), dtypes.int32))

  _tf_while_stmt(
      aug_test,
      aug_body,
      aug_get_state,
      aug_set_state,
      ('<internal iterate>',) + symbol_names,
      opts)


def _tf_iterator_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
679  """Overload of for_stmt that iterates over TF Iterators. See for_loop."""
  symbol_names = ('<internal has_next>',) + symbol_names
  has_next = True

  def aug_get_state():
    return (has_next,) + get_state()

  def aug_set_state(aug_loop_vars):
    nonlocal has_next
    # TODO(b/171479293): Drop the lint override.
    has_next, *loop_vars = aug_loop_vars  # pylint:disable=unused-variable
    set_state(loop_vars)

  init_vars = aug_get_state()
  _verify_loop_init_vars(init_vars, symbol_names)

  def aug_body():
    """Main body passed to _tf_while_stmt."""
    nonlocal has_next
    opt_iterate = iter_.get_next_as_optional()
    has_next = opt_iterate.has_value()
    loop_vars = aug_get_state()  # updated by set_state() in _tf_while_stmt.

    def main_path():
      body(opt_iterate.get_value())
      new_loop_vars = aug_get_state()
      # Note: this verification duplicates the one performed in
      # _tf_while_stmt, but needs to be done earlier to prevent the tf.cond
      # from blowing up first.
      _verify_tf_loop_vars(
          init_vars, loop_vars, new_loop_vars, symbol_names, opts)
      return new_loop_vars

    def noop_path():
      return loop_vars

    # TODO(mdan): If tf.while_loop supported Optional, this could be avoided.
    # Calling set_state so that the get_state() in _tf_while_stmt sees the
    # conditional tensors.
    aug_set_state(
        control_flow_ops.cond(has_next, main_path, noop_path))

  def aug_test():
    # This value takes a complicated path to get here:
    #   prev_iteration_body -> get_state -> tf.while_loop (as loop var)
    #   -> current_iteration_body -> set_state -> has_next
    main_test = has_next
    if extra_test is not None:
      return control_flow_ops.cond(main_test, extra_test, lambda: False)
    return main_test

  _tf_while_stmt(
      aug_test,
      aug_body,
      aug_get_state,
      aug_set_state,
      symbol_names,
      opts)


def _general_purpose_scan(ds, init_state, body):
  """Variant of Dataset.scan with semantics of general-purpose computation."""
  # Datasets are typically intended for data preprocessing. However, in
  # autograph loops they usually appear as general-purpose computations (for
  # example, a custom training loop). These two use cases require significantly
  # different optimization policies, the most important of which is the device
  # placement. The flag override for use_default_device below instructs the
  # runtime to treat the computation as general-purpose, rather than data
  # preprocessing.
  # TODO(mdan): s/use_default_device/specialize_for_input_pipeline.
  # TODO(mdan): Don't use private symbols.
  # pylint:disable=protected-access
  return dataset_ops._ScanDataset(
      ds, init_state, body, use_default_device=False)


def _tf_dataset_for_stmt(
    ds, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of _dataset_for_stmt with early stopping. See for_stmt."""
  # Note: This is easier to follow with the insight that the computations in
  # a dataset pipeline are transposed (aka fused).
  # For example, given a pipeline input -> scan -> take_while -> reduce,
  # and a dataset with input [1, 2, 3], the computations occur in the following
  # order:
  #  reduce(take_while(scan(1)))
  #  reduce(take_while(scan(2)))
  #  reduce(take_while(scan(3)))

  init_vars = get_state()
  _verify_loop_init_vars(init_vars, symbol_names)

  # Workaround for Dataset.reduce not allowing empty state tensors - create
  # a dummy state variable that remains unused.
  # TODO(mdan): reduce should allow and match empty structures.
  if not init_vars:
    init_vars = (constant_op.constant(0),)
    symbol_names = ('<internal dummy>',)

    def dummy_set_state(unused_dummy):
      pass

    def dummy_get_state():
      return (constant_op.constant(0),)

    get_state, set_state = dummy_get_state, dummy_set_state

  def scan_body(scan_state, scan_inputs):
    """Main body of the Dataset.scan."""
    loop_vars, iterate = scan_state, scan_inputs
    set_state(loop_vars)

    def main_path():
      body(iterate)
      new_loop_vars = get_state()
      _verify_tf_loop_vars(
          init_vars, loop_vars, new_loop_vars, symbol_names, opts,
          check_shapes=False)
      return new_loop_vars

    if extra_test is not None:
      extra_cond = extra_test()
      new_loop_vars = control_flow_ops.cond(
          extra_cond, main_path, lambda: loop_vars)
    else:
      # TODO(mdan): the optimizer should be able to remove an invariant cond?
      extra_cond = (constant_op.constant(True),)  # dummy value, unused
      new_loop_vars = main_path()

    scan_outputs = new_loop_vars, extra_cond
    new_scan_state = new_loop_vars
    return new_scan_state, scan_outputs

  def take_while_predicate(unused_loop_vars, extra_cond):
    return extra_cond

  def reduce_body(unused_reduce_state, scan_outputs):
    output_loop_vars, unused_extra_cond = scan_outputs
    new_reduce_state = output_loop_vars
    return new_reduce_state

  ds = _general_purpose_scan(ds, init_vars, scan_body)
  if extra_test is not None:
    ds = ds.apply(take_while_ops.take_while(take_while_predicate))
  final_loop_vars = ds.reduce(init_vars, reduce_body)
  set_state(final_loop_vars)


def _tf_distributed_iterable_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over TF distributed datasets."""

  if extra_test is not None:
    raise NotImplementedError(
        'break and return statements are not yet supported in '
        'for ... in distributed input loops.')

  init_vars = get_state()
  _verify_loop_init_vars(init_vars, symbol_names)

  if 'shape_invariants' in opts:
    opts['shape_invariants'] = _shape_invariants_mapping_to_positional_list(
        opts['shape_invariants'], init_vars)

  def reduce_body(loop_vars, iterate):
    set_state(loop_vars)
    body(iterate)
    new_loop_vars = get_state()
    _verify_tf_loop_vars(
        init_vars, loop_vars, new_loop_vars, symbol_names, opts)
    return new_loop_vars

  set_state(iter_.reduce(init_vars, reduce_body))


def while_stmt(test, body, get_state, set_state, symbol_names, opts):
  """Functional form of a while statement.

  The loop operates on a so-called state, which includes all symbols that are
  variant across loop iterations. In what follows we refer to state as either
  a tuple of entities that represent an actual state, or a list of arguments
  of the corresponding types.

  The inputs and outputs of the callables representing the loop blocks are not
  explicit - instead, these functions must use nonlocal/global for side effects.
  The inputs and outputs are instead controlled by the set_state/get_state
  functions.
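
  For example, a loop like `while x > 0: x -= 1` could be staged roughly as
  follows (an illustrative sketch, not the actual generated code):

  ```
    def test():
      return x > 0

    def body():
      nonlocal x
      x -= 1

    def get_state():
      return (x,)

    def set_state(loop_vars):
      nonlocal x
      x, = loop_vars

    while_stmt(test, body, get_state, set_state, ('x',), {})
  ```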

  Args:
    test: Callable with boolean return type. The loop condition.
    body: Callable representing the actual loop body.
    get_state: Additional callable which can capture additional state (such as
      the values of composite symbols). This is only useful when staging the
      loop.
    set_state: Additional callable which saves values captured by get_state back
      into the Python environment. This is only useful when staging the loop.
    symbol_names: Tuple containing the names of all loop variables.
    opts: Optional dict of extra loop parameters.

  Returns:
    Tuple containing the final state.
  """

  # Evaluate the initial test once in order to do the dispatch. The evaluation
  # is isolated to minimize unwanted side effects.
  # TODO(mdan): Do a full iteration - some state types might lower to Tensor.
  with func_graph.FuncGraph('tmp').as_default():
    init_test = test()

  # TensorFlow: Multiple evaluations are acceptable in this case, so we're fine
  # with the re-evaluation of `test` that `_tf_while_stmt` will make.
  if tensors.is_dense_tensor(init_test):
    _tf_while_stmt(test, body, get_state, set_state, symbol_names, opts)
    return

  # Normal Python: We already consumed one evaluation of `test`; consistently,
  # unroll one iteration before dispatching to a normal loop.
  # TODO(mdan): Push the "init_test" value via opts into _py_while_stmt?
  if not init_test:
    return
  body()

  _py_while_stmt(test, body, get_state, set_state, opts)


class _PythonLoopChecker(object):
  """Verifies Python loops for TF-specific limits."""

  __slots__ = (
      'iterations',
      'check_inefficient_unroll',
      'check_op_count_after_iteration',
      'ops_before_iteration',
      )

  def __init__(self):
    self.iterations = 1
    self.check_inefficient_unroll = WARN_INEFFICIENT_UNROLL

    # Triggered when we decided to test the op counts.
    self.check_op_count_after_iteration = False

  def _get_ops(self):
    return ops.get_default_graph().get_operations()

  def _check_unroll_limits(self):
    if self.iterations > PYTHON_MAX_ITERATIONS:
      raise ValueError('iteration limit exceeded')

  def _stop_checking_inefficient_unroll(self):
    self.check_inefficient_unroll = False
    self.check_op_count_after_iteration = False
    self.ops_before_iteration = None

  def _verify_inefficient_unroll(self):
    """Checks for possibly-inefficient creation of ops in a Python loop."""
    assert self.ops_before_iteration is not None
    ops_after_iteration = self._get_ops()
    new_ops = tuple(
        op for op in ops_after_iteration if op not in self.ops_before_iteration)

    if len(new_ops) < INEFFICIENT_UNROLL_MIN_OPS:
      return False

    ag_logging.warning(
        'Large unrolled loop detected. Did you mean to use a TF loop?'
        ' The following ops were created after iteration %s: %s'
        '\nSee'
        ' https://github.com/tensorflow/tensorflow/blob/master/'
        'tensorflow/python/autograph/g3doc/reference/common_errors.md'
        '#warning-large-unrolled-loop-detected'
        '\n'
        'Location:'
        '\n%s'
        '', self.iterations, new_ops, '\n'.join(traceback.format_stack()))
    return True

  def before_iteration(self):
    """Called before each iteration in a Python loop."""
    if (self.check_inefficient_unroll and
        self.iterations > INEFFICIENT_UNROLL_MIN_ITERATIONS):
      self.ops_before_iteration = self._get_ops()
      self.check_op_count_after_iteration = True

  def after_iteration(self):
    """Called after each iteration in a Python loop."""
    self.iterations += 1

    self._check_unroll_limits()

    if self.check_op_count_after_iteration:
      did_warn = self._verify_inefficient_unroll()
      if did_warn:
        self._stop_checking_inefficient_unroll()  # Only warn once.
      elif self.iterations > INEFFICIENT_UNROLL_MIN_ITERATIONS + 3:
        # Once deciding to check the op counts, only do it for a few iterations.
        self._stop_checking_inefficient_unroll()


def _py_while_stmt(test, body, get_state, set_state, opts):
  """Overload of while_stmt that executes a Python while loop."""
  del opts, get_state, set_state

  if __debug__:
    checker = _PythonLoopChecker()
    before_iteration = checker.before_iteration
    after_iteration = checker.after_iteration
    before_iteration()

    original_body = body
    def protected_body():
      original_body()
      after_iteration()
      before_iteration()
    body = protected_body

  def guarded_test():
    test_result = test()
    try:
      # Note: Using try/except and not tensor_util.is_tf_type to avoid
      # performance degradation.
      return bool(test_result)
    except errors_impl.OperatorNotAllowedInGraphError as e:
      ag_logging.log(
          1,
          'Caught error while evaluating while loop condition',
          exc_info=True)
      # TODO(mdan): distinguish between these two cases.
      raise NotImplementedError(
          'The condition of the while loop started as non-Tensor, then changed'
          ' to Tensor. This may happen either because variables changed type,'
          ' or when a break or return statement inside the loop depends on a'
          ' Tensor condition. In both cases, changing to a TF loop should'
          ' remove the error.\nSee '
          'https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
          'python/autograph/g3doc/reference/limitations.md'
          '#consistency-of-control-flow-types for more info.') from e
  while guarded_test():
    body()


def _shape_invariants_mapping_to_positional_list(mapping, keys):
  # The keys are not expected to be hashable.
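  # Illustrative example (hypothetical values): given mapping [(x, x_inv)] and
  # keys (x, y), the result is (x_inv, <structure of y with None leaves>);
  # keys without a user-specified invariant get None for every leaf.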
  mapping = {id(k): (k, v) for k, v in mapping}
  result = []
  for k in keys:
    map_key, map_val = mapping.get(id(k), (None, None))
    result.append(
        map_val if map_key is k else nest.map_structure(lambda _: None, k))
  return tuple(result)


# Textual description of what a legal TF loop variable is. This description
# summarizes types that _placeholder_value below can handle. Keep the two
# together and in sync.
LEGAL_LOOP_TYPES = 'Tensor, int, float, bool or a list, tuple or dict thereof'


def _placeholder_value(like, shape_invariant, original=None):
  """Constructs a (dummy) placeholder value for a loop-initialized variable.

  Args:
    like: Any object. The value created by the first iteration of the loop. If a
      Python scalar, the placeholder will be the zero value of that type. If a
      Tensor, the placeholder will be a zero tensor of matching shape and dtype.
      If a list, dict or tuple, the placeholder will be an identical structure
      of placeholders.
    shape_invariant: The shape invariant specified by the user (or None, if
      nothing was specified) for the respective variable.
    original: Any object. The value of the variable prior to entering the loop.
      Typically, this is one of the special "Undefined" values, because that's
      when a placeholder is needed.

  Returns:
    Either a zero value of structure, shape and dtype matching 'like', or
    'original', if no such zero value could be created.
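
    For example (illustrative): _placeholder_value(tf.constant([1, 2]), None)
    returns a zero tensor of shape (2,) and dtype int32, paired with a None
    shape invariant.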
1054  """
1055  if like is None:
1056    return original, None
1057
1058  elif isinstance(like, (variables.Undefined, variables.UndefinedReturnValue)):
1059    return original, None
1060
1061  elif isinstance(like, (int, float, bool)):
1062    return type(like)(0), None
1063
1064  elif tensor_util.is_tf_type(like):
1065
1066    like_shape = shape_invariant if shape_invariant is not None else like.shape
1067    if like_shape is None or like_shape.rank is None:
1068      return array_ops.zeros((), like.dtype), like_shape
1069
1070    # If the shape contains dynamic values, set the corresponding starting
1071    # dimension to either zero or what the shape invariant specified.
1072    placeholder_shape = []
1073    has_dynamic_dims = False
1074    for s, i in zip(like.shape, like_shape):
1075      if i is None:
1076        like_dim = 0
1077      elif isinstance(i, tensor_shape.Dimension):
1078        if i.value is None:
1079          like_dim = 0
1080        else:
1081          like_dim = i.value
1082      else:
1083        like_dim = i
1084
1085      if s is None:
1086        placeholder_shape.append(like_dim)
1087        has_dynamic_dims = True
1088      elif isinstance(s, tensor_shape.Dimension):
1089        if s.value is None:
1090          placeholder_shape.append(like_dim)
1091          has_dynamic_dims = True
1092        else:
1093          placeholder_shape.append(s.value)
1094      else:
1095        placeholder_shape.append(s)
1096
1097    if has_dynamic_dims:
1098      invariant = like_shape
1099    else:
1100      invariant = None
1101
1102    return array_ops.zeros(placeholder_shape, like.dtype), invariant
1103
1104  elif isinstance(like, (list, tuple, dict)):
1105    if shape_invariant is None:
1106      zipped = nest.map_structure(lambda v: _placeholder_value(v, None),
1107                                  nest.flatten(like))
1108    else:
1109      zipped = nest.map_structure(_placeholder_value, nest.flatten(like),
1110                                  nest.flatten(shape_invariant))
1111    vals, invars = zip(*zipped)
1112    return (nest.pack_sequence_as(like,
1113                                  vals), nest.pack_sequence_as(like, invars))
1114
1115  # This is to be caught by _try_handling_undefineds, to give more context.
1116  raise TypeError(
1117      "Found an unsupported type '{}' while creating placeholder for {}."
1118      ' Supported types include Tensor, int, float, bool, list, tuple or dict.'
1119      .format(type(like).__name__, like))
1120
1121
1122def _try_handling_undefineds(body, get_state, set_state, init_vars, nulls,
1123                             shape_invariants, symbol_names):
1124  """Makes a best-effort attempt to substitute undefineds with placeholders.
1125
1126  Note: this substitution requires two things to happen:
   1. the types of the loop variables can be inferred (usually by staging one
      iteration)
   2. these types can be replaced by placeholders (e.g. zero values for
      tensors)

  Args:
    body: a function representing the loop body. See while_stmt.
    get_state: state getter for the loop statement. See while_stmt.
    set_state: state setter for the loop statement. See while_stmt.
    init_vars: loop variables before entering the loop. See while_stmt.
    nulls: list of boolean flags indicating whether the corresponding loop var
      is None or undefined.
    shape_invariants: user-specified shape invariant for each loop variable.
    symbol_names: list of loop variable names. See while_stmt.

  Returns:
    A tuple (success, new_init_vars, extra_shape_invariants):
     * success is a boolean flag indicating whether types could be successfully
       inferred (step 1 above)
     * new_init_vars contains the loop vars, with None or undefined values
       replaced by default values, where possible (step 2 above)
     * extra_shape_invariants contains shape invariants that would be needed
       by while_stmt, for instance if the placeholder values had a shape
       different from the corresponding loop outputs
  """
  state_modified = False
  # Defensive default; in the failure path, the verification below raises
  # before `success` is returned.
  success = False
  first_iter_vars = None
  failure_message = None

  try:
    # Stage an iteration of the loop body in a temporary graph.
    with func_graph.FuncGraph('tmp').as_default():
      # This call to set_state helps report nicer error messages when symbols
      # are inconsistently used.
      # Another complication is that non-Tensor values will be autocast to
      # Tensor by while_loop, and their static value lost. So we need to
      # account for that here.
      def autocast_to_tensor(v):
        if isinstance(
            v, (int, float, bool, str, list, tuple, np.ndarray, np.generic)):
          init_val = ops.convert_to_tensor_v2(v)
          return array_ops.placeholder(init_val.dtype, init_val.shape)
        return v
      autocast_init_vars = nest.map_structure(autocast_to_tensor, init_vars)
      set_state(autocast_init_vars)
      state_modified = True

      body()
      first_iter_vars = get_state()

    # Note: the actual placeholder value doesn't matter, because as the
    # staging proved, it will be replaced by an actual value before being
    # read.
    inits_and_invariants = tuple(
        (_placeholder_value(iv, i, v) if n else (v, None))
        for v, n, iv, i in zip(init_vars, nulls, first_iter_vars,
                               shape_invariants))
    init_vars, extra_shape_invariants = zip(*inits_and_invariants)
    success = True

  except (UnboundLocalError, TypeError, ValueError, KeyError):
    ag_logging.log(1, 'Caught error while staging loop body', exc_info=True)
    # Fall back to the old functionality. It will likely result in an input
    # validation failure.
    exc = sys.exc_info()
    failure_message = (
        'Note: AutoGraph tried to define it automatically, but ran into a'
        ' {}: {}'.format(exc[0].__name__, exc[1]))

  finally:
    if state_modified:
      set_state(init_vars)

  # This check runs regardless, in case we captured non-Tensor inputs.
  _verify_loop_init_vars(
      init_vars, symbol_names, first_iter_vars, extra_message=failure_message)

  return success, init_vars, extra_shape_invariants


def _runtime_zero_iterations_errmsg(symbol_names, nulls, init_vars):
  """Creates an error message asking for the loop to iterate at least once."""
  var_names = []
  for sn, n, v in zip(symbol_names, nulls, init_vars):
    if not n:
      continue
    if isinstance(v, variables.UndefinedReturnValue):
      var_names.append('the function return value')
    else:
      var_names.append(sn)
  var_names = ', '.join(var_names)
  return 'loop must iterate at least once to initialize {}'.format(var_names)


def _tf_while_stmt(test, body, get_state, set_state, symbol_names, opts):
  """Overload of while_stmt that stages a TF while_stmt."""
  init_vars = get_state()
  orig_init_vars = init_vars

  nulls = tuple(_is_none_or_undef(v) for v in init_vars)
  if any(nulls):
    shape_invars_by_init_vals = {
        id(v): i for v, i in opts.get('shape_invariants', ())
    }
    shape_invariants = tuple(
        shape_invars_by_init_vals.get(id(v), None) for v in orig_init_vars)
    (require_one_iteration, init_vars,
     extra_shape_invariants) = _try_handling_undefineds(body, get_state,
                                                        set_state, init_vars,
                                                        nulls, shape_invariants,
                                                        symbol_names)
  else:
    require_one_iteration = False

  if require_one_iteration:
    merged_shape_invariants = dict(shape_invars_by_init_vals)
    # This has two roles:
    #  1. Shape invariants are remapped from the old init vars to the new ones.
    #  2. Any new shape invariants created by the init vars are kept, but only
    #     if the user didn't already specify some.
    for v, nv, ni in zip(orig_init_vars, init_vars, extra_shape_invariants):
      merged_invariant = merged_shape_invariants.get(id(v), ni)
      if merged_invariant is not None:
        merged_shape_invariants[id(nv)] = merged_invariant
    merged_shape_invariants = tuple((nv, merged_shape_invariants[id(nv)])
                                    for nv in init_vars
                                    if id(nv) in merged_shape_invariants)
    if merged_shape_invariants:
      opts = dict(**opts)
      opts['shape_invariants'] = merged_shape_invariants

  def aug_test(*loop_vars):
    if require_one_iteration:
      loop_vars = loop_vars[1:]

    set_state(loop_vars)
    return _verify_tf_condition(test(), 'while loop')

  def aug_body(*loop_vars):
    if require_one_iteration:
      loop_vars = loop_vars[1:]

    set_state(loop_vars)
    body()
    new_loop_vars = get_state()
    _verify_tf_loop_vars(
        init_vars, loop_vars, new_loop_vars, symbol_names, opts)

    if require_one_iteration:
      new_loop_vars = (True,) + new_loop_vars

    return new_loop_vars

  if 'shape_invariants' in opts:
    opts['shape_invariants'] = _shape_invariants_mapping_to_positional_list(
        opts['shape_invariants'], init_vars)

  while_loop_opts = dict(opts)
  while_loop_opts.pop('iterate_names', None)

  # Non-v2 while_loop unpacks the results when there is only one return value.
  # This enforces consistency across versions.
  while_loop_opts['return_same_structure'] = True

  if require_one_iteration:
    aug_init_vars = (False,) + init_vars
    if 'shape_invariants' in while_loop_opts:
      while_loop_opts['shape_invariants'] = (
          (None,) + while_loop_opts['shape_invariants'])
  else:
    aug_init_vars = init_vars

  final_loop_vars = control_flow_ops.while_loop(
      aug_test, aug_body, aug_init_vars, **while_loop_opts)

  if require_one_iteration:
    with ops.control_dependencies([
        control_flow_ops.Assert(final_loop_vars[0], [
            _runtime_zero_iterations_errmsg(symbol_names, nulls, orig_init_vars)
        ])
    ]):
      final_loop_vars = nest.map_structure(
          lambda v: (array_ops.identity(v) if tensor_util.is_tf_type(v) else v),
          final_loop_vars[1:],
      )

  set_state(final_loop_vars)


def if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts):
  """Functional form of an if statement.

  The conditional operates on a state, which includes all symbols whose values
  are a function of the branch taken.

  For example, given the code below that calculates the abs function:

  ```
    x = 1
    if x < 0:
      x = -x
  ```

  The state is represented by the variable `x`. The `body`, `orelse` and
  `set_state` functions must bind to the original `x` symbol, using `nonlocal`.
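
  A rough sketch of the functionalized form (illustrative only, not the actual
  generated code):

  ```
    def body():
      nonlocal x
      x = -x

    def orelse():
      pass

    def get_state():
      return (x,)

    def set_state(cond_vars):
      nonlocal x
      x, = cond_vars

    if_stmt(x < 0, body, orelse, get_state, set_state, ('x',), 1)
  ```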

  The inputs and outputs of the callables representing the conditional branches
  are not explicit - instead, these functions must use nonlocal/global for side
  effects. The inputs and outputs are instead controlled by the
  set_state/get_state functions.

  Args:
    cond: Boolean.
    body: Callable representing the main block of the conditional.
    orelse: Callable representing the else block of the conditional.
    get_state: Function that returns a tuple containing the values of all
      composite symbols modified within the conditional. This allows access to
      state that branches may mutate through side effects. This function is not
      needed and should not be called when dispatching to code matching Python's
      default semantics. This is useful for checkpointing to avoid unintended
      side-effects when staging requires evaluating all code-paths.
    set_state: Function to set the values of all composite symbols modified
      within the conditional. This is the complement to get_state, used to
      restore checkpointed values. The single argument is a tuple containing
      values for each composite symbol that may be modified in a branch of the
      conditional. This is usually the result of a call to get_state.
    symbol_names: Tuple containing the names of the variables manipulated by
      the conditional.
    nouts: Number of variables output by the statement. Vars which are not
      outputs will not be passed through staged control flow such as tf.cond.
      This includes variables that are defined before the conditional, but are
      not used after it.
  """
  # Note: tf.cond doesn't support SparseTensor.
  if tensors.is_dense_tensor(cond):
    _tf_if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts)
  else:
    _py_if_stmt(cond, body, orelse)


def _tf_if_stmt(
    cond, body, orelse, get_state, set_state, symbol_names, nouts):
  """Overload of if_stmt that stages a TF cond."""
  cond = _verify_tf_condition(cond, 'if statement')

  if not nouts:
    prev_get_state, prev_set_state = get_state, set_state
    # Control flow V1 wants at least one output.
    get_state = lambda: (0,) + prev_get_state()
    set_state = lambda v: prev_set_state(v[1:])
    symbol_names += ('<unused dummy>',)
    nouts = 1

  init_vars = get_state()

  # TODO(mdan): Use nonlocal once we no longer need to support py2.
  new_body_vars_ = [None]
  new_orelse_vars_ = [None]

  def aug_body():
    set_state(init_vars)
    body()
    new_body_vars = get_state()
    new_body_vars = new_body_vars[:nouts]
    new_body_vars_[0] = new_body_vars
    _verify_tf_cond_branch_vars(new_body_vars, symbol_names, 'main')
    if new_orelse_vars_[0] is not None:
      _verify_tf_cond_vars(new_body_vars, new_orelse_vars_[0], symbol_names)
    return new_body_vars

  def aug_orelse():
    set_state(init_vars)
    orelse()
    new_orelse_vars = get_state()
    new_orelse_vars = new_orelse_vars[:nouts]
    new_orelse_vars_[0] = new_orelse_vars
    _verify_tf_cond_branch_vars(new_orelse_vars, symbol_names, 'else')
    if new_body_vars_[0] is not None:
      _verify_tf_cond_vars(new_body_vars_[0], new_orelse_vars, symbol_names)
    return new_orelse_vars

  final_cond_vars = control_flow_ops.cond(
      cond, aug_body, aug_orelse, strict=True)
  final_cond_vars = final_cond_vars + init_vars[nouts:]

  set_state(final_cond_vars)


def _py_if_stmt(cond, body, orelse):
  """Overload of if_stmt that executes a Python if statement."""
  return body() if cond else orelse()