• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Control flow statements: loops, conditionals, etc.
16
17Python 2 compatibility version. Not maintained.
18
19Note: most of these operators accept pairs of get_state/set_state functions, to
20capture mutations that the corresponding code blocks might make. These
21mutations only need to be captured when staging the control flow, and they just
22work when reverting to Python behavior.
23
24__Examples__
25
26```
27while cond:
28  self.x += i
29```
30
31When the functionalized version is executed as a Python loop, it just works:
32
33```
34def loop_body():
35  self.x += i     # works as expected for Python loops
36```
37
38But it won't work for TF loops:
39
40```
41def loop_body():
42  self.x += i     # self.x has the wrong value!
43```
44
45get_state/set_state allow piping the mutations through the loop variables as
46well, in effect changing the loop body:
47
48```
49def loop_body(self_x):
50  self.x = self_x  # self.x now has the proper value
51  self.x += i      # the original block
52  self_x = self.x  # write self.x back into the loop vars
53  return self_x
54
55self_x = tf.while_loop(...)
self.x = self_x    # the result is properly captured
57```
58"""
59
60from __future__ import absolute_import
61from __future__ import division
62from __future__ import print_function
63
64import functools
65
66import numpy as np
67
68from tensorflow.python.autograph.operators import py_builtins
69from tensorflow.python.autograph.operators import variables
70from tensorflow.python.autograph.utils import ag_logging
71from tensorflow.python.autograph.utils import misc
72from tensorflow.python.autograph.utils import tensors
73from tensorflow.python.data.experimental.ops import take_while_ops
74from tensorflow.python.data.ops import dataset_ops
75from tensorflow.python.data.ops import iterator_ops
76from tensorflow.python.framework import constant_op
77from tensorflow.python.framework import dtypes
78from tensorflow.python.framework import func_graph
79from tensorflow.python.framework import ops
80from tensorflow.python.framework import tensor_util
81from tensorflow.python.ops import array_ops
82from tensorflow.python.ops import control_flow_ops
83from tensorflow.python.ops import math_ops
84from tensorflow.python.ops import tensor_array_ops
85from tensorflow.python.ops.ragged import ragged_tensor
86from tensorflow.python.util import lazy_loader
87from tensorflow.python.util import nest
88
89
90# TODO(b/145618471): Remove this dependency.
91# Lazy import to work around circular dependencies
92input_lib = lazy_loader.LazyLoader(
93    'input_lib', globals(),
94    'tensorflow.python.distribute.input_lib')
95
96LIMIT_PYTHON_ITERATIONS = True
97PYTHON_MAX_ITERATIONS = 100000000  # Fails in about one minute for empty loops.
98WARN_INEFFICIENT_UNROLL = True
99INEFFICIENT_UNROLL_MIN_ITERATIONS = 3000
100INEFFICIENT_UNROLL_MIN_OPS = 1
101
102
def _disallow_undefs_into_loop(*values):
  """Ensures that all values in the state are defined when entering a loop."""
  # First pass: report any symbols that are still undefined at loop entry.
  undefined = tuple(v for v in values if isinstance(v, variables.Undefined))
  if undefined:
    names = ','.join(s.symbol_name for s in undefined)
    raise ValueError('{} must be defined before the loop.'.format(names))

  # Second pass: reject loops that captured a pending return value.
  for value in values:
    if isinstance(value, variables.UndefinedReturnValue):
      # Assumption: the loop will only capture the variable which tracks the
      # return value if the loop contained a return statement.
      # TODO(mdan): This should be checked at the place where return occurs.
      raise ValueError(
          'return statements are not supported within a TensorFlow loop.')
117
118
119def _is_subshape(left, right):
120  """Returns True if left shape is at least as specific as right shape."""
121  # TODO(mdan): This code should be in TensorShape.
122  # Note: this is not the same as TensorShape.is_compatible_with, which is
123  # symmetric.
124  # This code also duplicates _ShapeLessThanOrEqual from  control_flow_ops.py.
125  if right.dims is None:
126    return True
127  if left.ndims != right.ndims:
128    return False
129  for ldim, rdim in zip(left.dims, right.dims):
130    if rdim.value is not None and ldim.value != rdim.value:
131      return False
132  return True
133
134
# TODO(mdan): Remove these verifications once TF ops can properly report names.
def _verify_single_loop_var(
    name, check_shape, init, entry, exit_, shape_invariant):
  """Verifies whether the initial, entry and exit values are consistent.

  Args:
    name: Name of the loop variable, used in error messages.
    check_shape: Whether shape consistency should be enforced.
    init: Value of the variable before the loop.
    entry: Value of the variable when entering one iteration.
    exit_: Value of the variable when exiting one iteration.
    shape_invariant: Optional shape invariant. When set, `init` and `exit_`
      are checked against it instead of requiring the shape to remain
      unchanged across one iteration.

  Raises:
    TypeError: If the dtype changes across one iteration.
    ValueError: If a shape is inconsistent across iterations or with the
      shape invariant.
  """
  # Normalize Python and NumPy literals so that the dtype/shape checks below
  # apply uniformly to them.
  if isinstance(init, (bool, int, float, str, np.ndarray)):
    init = ops.convert_to_tensor_v2(init)
  if isinstance(entry, (bool, int, float, str, np.ndarray)):
    entry = ops.convert_to_tensor_v2(entry)
  # Fix: np.ndarray was previously missing here, so a variable that exited
  # the iteration as a NumPy array skipped verification entirely.
  if isinstance(exit_, (bool, int, float, str, np.ndarray)):
    exit_ = ops.convert_to_tensor_v2(exit_)

  if (not tensor_util.is_tf_type(entry) or
      not tensor_util.is_tf_type(exit_)):
    return

  # TODO(mdan): Properly account for CompositeTensors.
  if (not hasattr(entry, 'dtype') or
      not hasattr(exit_, 'dtype')):
    return
  if (not hasattr(entry, 'shape') or
      not hasattr(exit_, 'shape')):
    return

  if entry.dtype != exit_.dtype:
    raise TypeError(
        '"{}" has dtype {} before the loop, but dtype {} after one'
        ' iteration. TensorFlow control flow requires it stays the'
        ' same.'.format(
            name,
            entry.dtype.name,
            exit_.dtype.name,
        ))
  if check_shape:
    exit_shape = exit_.shape
    if shape_invariant is None:
      # Without an explicit invariant, the exit shape must be at least as
      # specific as the entry shape.
      entry_shape = entry.shape
      if not _is_subshape(exit_shape, entry_shape):
        raise ValueError(
            '"{}" has shape {} before the loop, but shape {} after one'
            ' iteration. Use tf.autograph.experimental.set_loop_options to set'
            ' shape invariants.'.format(name, entry_shape, exit_shape))
    else:
      # With an explicit invariant, both initial and exit shapes must conform
      # to it.
      init_shape = init.shape
      if not _is_subshape(init_shape, shape_invariant):
        raise ValueError(
            '"{}" has shape {} before the loop, which does not conform with'
            ' the shape invariant {}.'.format(name, init_shape,
                                              shape_invariant))
      if not _is_subshape(exit_shape, shape_invariant):
        raise ValueError(
            '"{}" has shape {} after the loop, which does not conform with'
            ' the shape invariant {}.'.format(
                name, exit_shape, shape_invariant))
188
189
def _verify_tf_loop_vars(init_vars,
                         iter_entry_vars,
                         iter_exit_vars,
                         symbol_names,
                         opts,
                         check_shapes=True):
  """Verifies loop variables for consistency.

  Checks each loop variable for matching nested structure between entry and
  exit of one iteration, for conformance with a user-supplied shape invariant
  (when present in `opts`), and for per-leaf dtype/shape consistency.

  Args:
    init_vars: Tuple of values before the loop.
    iter_entry_vars: Tuple of values when entering one iteration.
    iter_exit_vars: Tuple of values when exiting one iteration.
    symbol_names: Tuple of names matching the variables, for error messages.
    opts: Dict of loop options; may contain 'shape_invariants'.
    check_shapes: Whether shapes should be verified at all.

  Raises:
    TypeError: If structures or dtypes are inconsistent.
  """
  if check_shapes and 'shape_invariants' in opts:
    shape_invariants = opts['shape_invariants']
  else:
    # No invariants supplied (or shape checking disabled): pair each entry
    # variable with a None placeholder so the zip below stays aligned.
    shape_invariants = nest.map_structure(lambda _: None, iter_entry_vars)

  named_vars = zip(symbol_names, init_vars, iter_entry_vars, iter_exit_vars,
                   shape_invariants)
  for name, init, entry, exit_, invariant in named_vars:
    # The nested structure must not change across one iteration.
    try:
      nest.assert_same_structure(entry, exit_, expand_composites=True)
    except (ValueError, TypeError) as e:
      raise TypeError('"{}" does not have the same nested structure after one'
                      ' iteration.\n\n{}'.format(name, e))
    if invariant is not None:
      # The invariant, when given, must mirror the variable's structure.
      try:
        nest.assert_same_structure(init, invariant, expand_composites=False)
      except (ValueError, TypeError) as e:
        raise TypeError('"{}" does not have the same nested structure as its'
                        ' corresponding shape invariant.\n\n{}'.format(name, e))

    # Leaf-level dtype/shape verification.
    nest.map_structure(
        functools.partial(_verify_single_loop_var, name, check_shapes), init,
        entry, exit_, invariant)
220
221
def _verify_single_cond_var(name, body_var, orelse_var):
  """Verifies whether body_var and orelse_var are consistent."""
  # Normalize Python literals so the dtype comparison below applies to them.
  if isinstance(body_var, (bool, int, float, str)):
    body_var = ops.convert_to_tensor_v2(body_var)
  if isinstance(orelse_var, (bool, int, float, str)):
    orelse_var = ops.convert_to_tensor_v2(orelse_var)

  # Only tensor values are subject to the dtype consistency check.
  if not tensor_util.is_tf_type(body_var):
    return
  if not tensor_util.is_tf_type(orelse_var):
    return

  # TODO(mdan): Properly account for CompositeTensors.
  if not hasattr(body_var, 'dtype'):
    return
  if not hasattr(orelse_var, 'dtype'):
    return

  if body_var.dtype != orelse_var.dtype:
    raise TypeError(
        '"{}" has dtype {} in the TRUE branch, but dtype={} in the FALSE'
        ' branch. TensorFlow control flow requires that they are the'
        ' same.'.format(name, body_var.dtype.name,
                        orelse_var.dtype.name))
245
246
def _verify_tf_cond_vars(body_vars, orelse_vars, symbol_names):
  """Verifies variables manipulated by a conditional for consistency."""
  basic_body_vars, composite_body_vars = body_vars
  basic_orelse_vars, composite_orelse_vars = orelse_vars
  assert isinstance(composite_body_vars, tuple)
  assert isinstance(composite_orelse_vars, tuple)

  # TODO(kkb): Make this more consistent.
  # The basic outputs should always be a tuple.
  if not isinstance(basic_body_vars, tuple):
    basic_body_vars = (basic_body_vars,)
  if not isinstance(basic_orelse_vars, tuple):
    basic_orelse_vars = (basic_orelse_vars,)

  # Verify the basic and composite variables in one flat pass.
  all_body_vars = basic_body_vars + composite_body_vars
  all_orelse_vars = basic_orelse_vars + composite_orelse_vars

  for name, body_var, orelse_var in zip(
      symbol_names, all_body_vars, all_orelse_vars):
    # Each variable must have the same nested structure in both branches.
    try:
      nest.assert_same_structure(
          body_var, orelse_var, expand_composites=True)
    except (ValueError, TypeError) as e:
      raise TypeError(
          '"{}" does not have the same nested structure in the TRUE and FALSE'
          ' branches.\n\n{}'.format(name, str(e)))

    # Leaf-level dtype verification.
    nest.map_structure(
        functools.partial(_verify_single_cond_var, name), body_var, orelse_var)
276
277
def for_stmt(iter_,
             extra_test,
             body,
             get_state,
             set_state,
             init_vars,
             basic_symbol_names,
             composite_symbol_names,
             opts):
  """Functional form of a for statement.

  The loop operates on a state, which includes all symbols that are
  variant across loop iterations, excluding the iterate as well as the
  variables local to the loop.

  For example, given the loop below that calculates the geometric and
  arithmetic means of some numbers:

    geo_mean = 1
    arith_mean = 0
    for i in range(n):
      a = numbers[i]
      geo_mean *= a
      arith_mean += a

  The state is represented by the variables geo_mean and arith_mean. The
  argument for initial_state may contain the tuple (1, 0), the body will
  include the arguments geo_mean and arith_mean and will return a tuple
  representing the new values for geo_mean and respectively arith_mean.

  This function dispatches to a specialized implementation based on the type
  of `iter_`; when `iter_` is not a TF entity, it falls back to an ordinary
  Python loop.

  Args:
    iter_: The entity being iterated over.
    extra_test: Callable with the state as arguments, and boolean return type.
      An additional loop condition.
    body: Callable with the iterate and the state as arguments, and state as
      return type. The actual loop body.
    get_state: Additional callable which can capture additional state (such as
      the values of composite symbols). This is only useful when staging the
      loop.
    set_state: Additional callable which save values captured by get_state back
      into the Python environment. This is only useful when staging the loop.
    init_vars: Tuple containing the initial state.
    basic_symbol_names: Tuple containing basic loop var names.
    composite_symbol_names: Tuple containing composite loop var names.
    opts: Optional dict of extra loop parameters.

  Returns:
    Tuple containing the final state.
  """
  if tensor_util.is_tf_type(iter_):
    # Dense tensors: ranges get a specialized elided implementation; anything
    # else must admit a length.
    if tensors.is_range_tensor(iter_):
      return _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state,
                                init_vars, basic_symbol_names,
                                composite_symbol_names, opts)
    else:
      return _known_len_tf_for_stmt(iter_, extra_test, body, get_state,
                                    set_state, init_vars, basic_symbol_names,
                                    composite_symbol_names, opts)

  if isinstance(iter_, dataset_ops.DatasetV2):
    return _tf_dataset_for_stmt(iter_, extra_test, body, get_state, set_state,
                                init_vars, basic_symbol_names,
                                composite_symbol_names, opts)

  if isinstance(iter_, iterator_ops.OwnedIterator):
    return _tf_iterator_for_stmt(iter_, extra_test, body, get_state, set_state,
                                 init_vars, basic_symbol_names,
                                 composite_symbol_names, opts)

  if isinstance(iter_, ragged_tensor.RaggedTensor):
    return _tf_ragged_for_stmt(iter_, extra_test, body, get_state, set_state,
                               init_vars, basic_symbol_names,
                               composite_symbol_names, opts)

  if isinstance(iter_, input_lib.DistributedIterator):
    raise NotImplementedError(
        'distributed iterators not supported yet, use the distributed dataset'
        ' directly')

  if isinstance(iter_, input_lib.DistributedDataset):
    return _tf_distributed_dataset_for_stmt(iter_, extra_test, body, init_vars)

  # Fallback: ordinary Python iteration.
  return _py_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars)
361
362
363def _py_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars):
364  """Overload of for_stmt that executes a Python for loop."""
365  del get_state, set_state
366  state = init_vars
367
368  if extra_test is not None:
369    if extra_test(*state):
370      for target in iter_:
371        state = body(target, *state)
372        if not extra_test(*state):
373          break
374
375  else:
376    for target in iter_:
377      state = body(target, *state)
378
379  return state
380
381
def _known_len_tf_for_stmt(iter_,
                           extra_test,
                           body,
                           get_state,
                           set_state,
                           init_vars,
                           basic_symbol_names,
                           composite_symbol_names,
                           opts):
  """Overload of for_stmt that iterates over TF entities that admit a length.

  The iterated tensor is unstacked into a TensorArray and a synthetic index
  variable is threaded through the while loop as the first loop variable; it
  is stripped from the results before returning.
  """
  _disallow_undefs_into_loop(*init_vars)

  n = py_builtins.len_(iter_)
  # TODO(b/117628877): Revisit performance once XLA has the necessary support.
  # Note: using a TensorArray creates an extra copy, but can calculate
  # gradients more efficiently than StridedSlice.
  ta = tensor_array_ops.TensorArray(iter_.dtype, size=n)
  iter_ = ta.unstack(iter_)

  def while_body(iterate_index, *loop_vars):
    """Main loop body."""
    iterate = iter_.read(iterate_index)
    new_vars = body(iterate, *loop_vars)

    # The iteration index is always the first loop variable.
    loop_vars = (iterate_index + 1,)
    if new_vars:
      loop_vars += new_vars

    return loop_vars

  def while_cond(iterate_index, *loop_vars):
    # When an extra_test is present, it may only be evaluated while the index
    # is still in range, hence the nested cond.
    if extra_test is not None:
      return control_flow_ops.cond(iterate_index < n,
                                   lambda: extra_test(*loop_vars),
                                   lambda: False)
    return iterate_index < n

  opts['maximum_iterations'] = n

  results = _tf_while_stmt(
      while_cond,
      while_body,
      get_state,
      set_state,
      (array_ops.zeros_like(n),) + init_vars,
      ('<internal iterate>',) + basic_symbol_names,
      composite_symbol_names,
      opts,
  )

  # Note: the iteration index is not returned by the while loop, however
  # if a symbol with the same name exists outside the loop, it will be captured
  # by the loop variables and ultimately updated correctly.
  if isinstance(results, (tuple, list)):
    assert len(results) >= 1  # Has at least the iterate.
    if len(results) > 1:
      results = results[1:]
  else:
    results = ()

  return results
443
444
def _tf_ragged_for_stmt(iter_,
                        extra_test,
                        body,
                        get_state,
                        set_state,
                        init_vars,
                        basic_symbol_names,
                        composite_symbol_names,
                        opts):
  """Overload of for_stmt that iterates over TF ragged tensors.

  A synthetic index variable is threaded through the while loop as the first
  loop variable; it is stripped from the results before returning.
  """
  _disallow_undefs_into_loop(*init_vars)

  # TODO(mdan): Move this into len()? Requires eager support.
  if iter_.shape and iter_.shape[0] is not None:
    # Statically known number of rows.
    n = iter_.shape[0]
  else:
    n = iter_.row_lengths()[0]

  def while_body(iterate_index, *loop_vars):
    """Main loop body."""
    iterate = iter_[iterate_index]
    new_vars = body(iterate, *loop_vars)

    # The iteration index is always the first loop variable.
    loop_vars = (iterate_index + 1,)
    if new_vars:
      loop_vars += new_vars

    return loop_vars

  def while_cond(iterate_index, *loop_vars):
    # When an extra_test is present, it may only be evaluated while the index
    # is still in range, hence the nested cond.
    if extra_test is not None:
      return control_flow_ops.cond(
          iterate_index < n,
          lambda: extra_test(*loop_vars),
          lambda: False,
      )
    return iterate_index < n

  # Fix: this was previously assigned twice (once before the body definitions
  # and once here); a single assignment is sufficient.
  opts['maximum_iterations'] = n

  results = _tf_while_stmt(
      while_cond,
      while_body,
      get_state,
      set_state,
      (array_ops.zeros_like(n),) + init_vars,
      ('<internal iterate>',) + basic_symbol_names,
      composite_symbol_names,
      opts,
  )

  # The internal iteration index is not part of the user-visible state.
  if isinstance(results, (tuple, list)):
    assert len(results) >= 1  # Has at least the iterate.
    if len(results) > 1:
      results = results[1:]
  else:
    results = ()

  return results
506
507
def _tf_range_for_stmt(iter_,
                       extra_test,
                       body,
                       get_state,
                       set_state,
                       init_vars,
                       basic_symbol_names,
                       composite_symbol_names,
                       opts):
  """Overload of for_stmt that iterates over a TF range (and elides it).

  Instead of materializing the range, the loop reads the range op's inputs
  (start, limit, delta) and threads the iterate itself through the while loop
  as the first loop variable; it is stripped from the results on exit.
  """
  _disallow_undefs_into_loop(*init_vars)

  # Elide the range tensor: recover the parameters from the range op itself.
  start, limit, delta = iter_.op.inputs

  def while_body(iterate, *loop_vars):
    new_vars = body(iterate, *loop_vars)
    # The iterate is always the first loop variable.
    loop_vars = (iterate + delta,)

    if new_vars:
      loop_vars += new_vars

    return loop_vars

  def while_cond(iterate, *loop_vars):
    """Cond function for `tf.while_loop`."""
    # Mirrors range() semantics: direction of the comparison depends on the
    # sign of delta.
    main_test = math_ops.logical_or(
        math_ops.logical_and(delta >= 0, iterate < limit),
        math_ops.logical_and(delta < 0, iterate > limit))
    if extra_test is not None:
      # extra_test may only be evaluated while main_test holds.
      return control_flow_ops.cond(
          main_test,
          lambda: extra_test(*loop_vars),
          lambda: False,
      )
    return main_test

  opts['maximum_iterations'] = math_ops.cast(
      misc.get_range_len(start, limit, delta), dtypes.int32)

  results = _tf_while_stmt(
      while_cond,
      while_body,
      get_state,
      set_state,
      (start,) + init_vars,
      ('<internal iterate>',) + basic_symbol_names,
      composite_symbol_names,
      opts,
  )

  # Note: the iteration index is not returned by the while loop, however
  # if a symbol with the same name exists outside the loop, it will be captured
  # by the loop variables and ultimately updated correctly.
  if isinstance(results, (tuple, list)):
    assert len(results) >= 1  # Has at least the iterate.
    if len(results) > 1:
      results = results[1:]
  else:
    results = ()

  return results
569
570
def _tf_iterator_for_stmt(itr, extra_test, body, get_state, set_state,
                          init_vars, basic_symbol_names,
                          composite_symbol_names, opts):
  """Overload of for_stmt that iterates over TF Iterators. See for_loop.

  A synthetic `has_next` boolean is threaded through the while loop as the
  first loop variable; it is stripped from the results before returning.
  """
  _disallow_undefs_into_loop(*init_vars)

  def while_body_actual(opt_iterate, *loop_vars):
    """Actual main loop body."""
    new_vars = body(opt_iterate.get_value(), *loop_vars)
    # TODO(mdan): Fix this inconsistency in the converter.
    if new_vars is None:
      new_vars = ()
    # Note: this verification duplicates that performed in tf_while_stmt,
    # but needs to be done earlier to prevent the tf.cond inside while_body
    # from blowing up first.
    _verify_tf_loop_vars(init_vars, loop_vars, new_vars,
                         basic_symbol_names + composite_symbol_names, opts)
    return new_vars

  def while_body(has_next, *loop_vars):
    """Main loop body."""
    opt_iterate = itr.get_next_as_optional()
    has_next = opt_iterate.has_value()

    if not init_vars:
      # cond_v2 requires at least one state tensor in V1.
      dummy_state = (constant_op.constant(()),)
    else:
      dummy_state = ()

    # TODO(mdan): If tf.while_loop supported Optional, this could be avoided.
    # The body may only run while the optional holds a value.
    new_vars = control_flow_ops.cond(
        has_next,
        lambda: dummy_state + while_body_actual(opt_iterate, *loop_vars),
        lambda: dummy_state + loop_vars,
    )

    if dummy_state:
      # Strip the dummy state tensor added above.
      new_vars = new_vars[1:]

    return (has_next,) + new_vars

  def while_cond(has_next, *loop_vars):
    # extra_test may only be evaluated while the iterator has a value.
    if extra_test is not None:
      return control_flow_ops.cond(
          has_next,
          lambda: extra_test(*loop_vars),
          lambda: False,
      )
    return has_next

  final_vars = _tf_while_stmt(
      while_cond,
      while_body,
      get_state,
      set_state,
      (True,) + init_vars,
      ('<internal has_next>',) + basic_symbol_names,
      composite_symbol_names,
      opts,
  )
  # Drop the internal has_next variable.
  return final_vars[1:]
633
634
def _tf_dataset_for_stmt(ds, extra_test, body, get_state, set_state, init_vars,
                         basic_symbol_names, composite_symbol_names, opts):
  """Overload of for_stmt that iterates over TF Datasets."""
  _disallow_undefs_into_loop(*init_vars)

  # Dispatch on whether early stopping is required; the two implementations
  # differ substantially.
  if extra_test is None:
    return _dataset_for_stmt_no_extra_test(ds, body, get_state, set_state,
                                           init_vars, basic_symbol_names,
                                           composite_symbol_names, opts)

  assert init_vars, 'Lowering should always add state.'
  return _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state,
                                           set_state, init_vars,
                                           basic_symbol_names,
                                           composite_symbol_names, opts)
650
651
def _general_purpose_scan(ds, init_state, body):
  """Variant of Dataset.scan with semantics of general-purpose computation.

  Args:
    ds: The dataset being iterated.
    init_state: Initial state for the scan.
    body: Scan function; called with (state, element) per the usage in this
      module.

  Returns:
    A dataset produced by scanning `ds` with `body`.
  """
  # Datasets are typically intended for data preprocessing. However, in
  # autograph loops they usually appear as general-purpose computations (for
  # example, a custom training loop). These two use cases require significantly
  # different optimization policies, the most important of which is the device
  # placement. The flag override for use_default_device below instructs the
  # runtime to treat the computation as general-purpose, rather than data
  # preprocessing.
  # TODO(mdan): s/use_default_device/specialize_for_input_pipeline.
  # TODO(mdan): Don't use private symbols.
  # pylint:disable=protected-access
  return dataset_ops._ScanDataset(
      ds, init_state, body, use_default_device=False)
666
667
def _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state,
                                      set_state, init_vars, basic_symbol_names,
                                      composite_symbol_names, opts):
  """Overload of _dataset_for_stmt with early stopping. See for_stmt.

  The loop is staged as a scan over the dataset, followed by a take_while
  that implements the early stopping, followed by a reduce that extracts the
  final state.
  """

  # TODO(mdan): Simplify this - following it is extremely difficult.

  init_state = get_state()
  aug_init_vars = init_vars, init_state

  def scan_body(aug_vars, iterate):
    """The main loop body wrapper. Only calculates the stop condition."""
    loop_vars, state = aug_vars

    def true_fn():
      """Main path - stop condition is not set."""
      # Restore the composite symbol values before running the user body, and
      # capture any mutations it made afterwards.
      set_state(state)
      new_vars = body(iterate, *loop_vars)
      new_state = get_state()
      _verify_tf_loop_vars(
          init_vars + init_state,
          loop_vars + state,
          new_vars + new_state,
          basic_symbol_names + composite_symbol_names,
          opts,
          check_shapes=False)
      return new_vars, new_state

    # The user body may only run while extra_test still holds; otherwise the
    # state is passed through unchanged.
    extra_cond = extra_test(*loop_vars)
    new_vars, new_state = control_flow_ops.cond(
        extra_cond,
        true_fn,
        lambda: (loop_vars, state),
    )

    scan_outputs = new_vars, new_state, extra_cond
    # Note: new_aug_vars is the actual state of scan; scan_outputs is its output
    # (hence the redundancy).
    # get_state will pull any mutations that body may have made.
    new_aug_vars = new_vars, new_state
    return new_aug_vars, scan_outputs

  def take_while_predicate(unused_loop_vars, unused_state, extra_cond):
    # Stop consuming elements as soon as the extra condition became False.
    return extra_cond

  def reduce_body(unused_aug_vars, scan_outputs):
    # Keep only the latest scan output; this yields the final loop state.
    output_aug_vars, output_state, extra_cond = scan_outputs
    del extra_cond
    return output_aug_vars, output_state

  ds = _general_purpose_scan(ds, aug_init_vars, scan_body)
  ds = ds.apply(take_while_ops.take_while(take_while_predicate))
  final_aug_vars = ds.reduce(aug_init_vars, reduce_body)
  final_vars, final_state = final_aug_vars
  # Write the captured composite symbol values back into Python.
  set_state(final_state)
  return final_vars
724
725
def _dataset_for_stmt_no_extra_test(ds, body, get_state, set_state, init_vars,
                                    basic_symbol_names, composite_symbol_names,
                                    opts):
  """Overload of _dataset_for_stmt without early stopping. See for_stmt.

  The loop is staged as a scan over the dataset followed by a reduce that
  extracts the final state. Dummy state tensors are threaded through when the
  loop has no variables or no composite state, because Dataset.reduce does
  not allow empty state tensors.
  """
  init_state = get_state()
  assert isinstance(init_vars, tuple)
  assert isinstance(init_state, tuple)

  symbol_names = basic_symbol_names + composite_symbol_names

  # Workaround for Dataset.reduce not allowing empty state tensors - create
  # a dummy state variable that remains unused.
  # TODO(mdan): reduce should allow and match empty structures.
  no_vars = not init_vars
  no_state = not init_state

  if no_vars:
    init_vars = (constant_op.constant(0),)
    symbol_names = ('<internal dummy>',) + symbol_names
  if no_state:
    init_state = (constant_op.constant(0),)
    symbol_names = symbol_names + ('<internal dummy>',)

  def scan_body(aug_vars, iterate):
    """The main loop body wrapper."""
    loop_vars, state = aug_vars
    if not no_state:
      # Restore the composite symbol values before running the user body.
      set_state(state)

    if no_vars:
      # The body returns nothing useful; keep the dummy state threaded.
      body(iterate)
      new_vars = loop_vars
    else:
      new_vars = body(iterate, *loop_vars)

    if no_state:
      new_state = state
    else:
      # Capture any mutations the body made to composite symbols.
      new_state = get_state()

    _verify_tf_loop_vars(
        init_vars + init_state,
        loop_vars + state,
        new_vars + new_state,
        symbol_names,
        opts,
        check_shapes=False)

    scan_outputs = new_vars, new_state
    # Note: new_aug_vars is the actual state of scan; scan_outputs is its output
    # (hence the redundancy).
    # get_state will pull any mutations that body may have made.
    new_aug_vars = new_vars, new_state
    return new_aug_vars, scan_outputs

  def reduce_body(unused_aug_vars, scan_outputs):
    # Keep only the latest scan output; this yields the final loop state.
    output_aug_vars, output_state = scan_outputs
    return output_aug_vars, output_state

  aug_vars = init_vars, get_state()
  ds = _general_purpose_scan(ds, aug_vars, scan_body)
  final_vars, final_state = ds.reduce(aug_vars, reduce_body)
  # Write the captured composite symbol values back into Python.
  set_state(final_state)

  if no_vars:
    return ()
  return final_vars
793
794
def _tf_distributed_dataset_for_stmt(iter_, extra_test, body, init_state):
  """Overload of for..in statement that iterates over the input."""
  _disallow_undefs_into_loop(*init_state)

  # Early stopping is not supported on distributed datasets.
  if extra_test is not None:
    raise NotImplementedError(
        'break and return statements are not yet supported in '
        'for ... in distributed input loops.')

  def reduce_body(state, iterate):
    return body(iterate, *state)

  if init_state:
    return iter_.reduce(init_state, reduce_body)

  # reduce requires a non-empty state; thread a dummy constant through and
  # discard it afterwards.
  def reduce_body_with_dummy_state(state, iterate):
    reduce_body((), iterate)
    return state
  iter_.reduce((constant_op.constant(0),), reduce_body_with_dummy_state)
  return ()
816
817
def while_stmt(test,
               body,
               get_state,
               set_state,
               init_vars,
               basic_symbol_names,
               composite_symbol_names,
               opts):
  """Functional form of a while statement.

  The loop operates on a so-called state, which includes all symbols that are
  variant across loop iterations. In what follows we refer to state as either
  a tuple of entities that represent an actual state, or a list of arguments
  of the corresponding types.

  Dispatches between a staged TF while loop and an ordinary Python loop based
  on the type of the initial test value.

  Args:
    test: Callable with the state as arguments, and boolean return type. The
      loop condition.
    body: Callable with the state as arguments, and state as return type. The
      actual loop body.
    get_state: Additional callable which can capture additional state (such as
      the values of composite symbols). This is only useful when staging the
      loop.
    set_state: Additional callable which save values captured by get_state back
      into the Python environment. This is only useful when staging the loop.
    init_vars: Tuple containing the initial state.
    basic_symbol_names: Tuple containing basic loop var names.
    composite_symbol_names: Tuple containing composite loop var names.
    opts: Optional dict of extra loop parameters.

  Returns:
    Tuple containing the final state.
  """

  # Evaluate the initial test once in order to do the dispatch. The evaluation
  # is isolated to minimize unwanted side effects.
  # TODO(mdan): Do a full iteration - some state types might lower to Tensor.
  # The temporary FuncGraph keeps any ops created by `test` out of the
  # surrounding graph.
  with func_graph.FuncGraph('tmp').as_default():
    init_test = test(*init_vars)

  # TensorFlow: Multiple evaluations are acceptable in this case, so we're fine
  # with the re-evaluation of `test` that `_tf_while_stmt` will make.
  if tensors.is_dense_tensor(init_test):
    return _tf_while_stmt(test, body, get_state, set_state, init_vars,
                          basic_symbol_names, composite_symbol_names, opts)

  # Normal Python: We already consumed one evaluation of `test`; consistently,
  # unroll one iteration before dispatching to a normal loop.
  # TODO(mdan): Push the "init_test" value via opts into _py_while_stmt?
  if not init_test:
    return init_vars
  init_vars = body(*init_vars)

  return _py_while_stmt(test, body, get_state, set_state, init_vars, opts)
872
873
874def _shape_invariants_mapping_to_positional_list(mapping, keys):
875  # The keys are not expected to be hashable.
876  mapping = {id(k): (k, v) for k, v in mapping}
877  result = []
878  for k in keys:
879    map_key, map_val = mapping.get(id(k), (None, None))
880    result.append(map_val if map_key is k else None)
881  return tuple(result)
882
883
def _tf_while_stmt(test, body, get_state, set_state, init_vars,
                   basic_symbol_names, composite_symbol_names, opts):
  """Overload of while_stmt that stages a TF while_stmt."""
  _disallow_undefs_into_loop(*init_vars)

  # Composite state (e.g. mutated object attributes) is appended after the
  # basic loop vars so tf.while_loop threads both through the iterations.
  num_basic = len(init_vars)
  aug_init_vars = init_vars + get_state()

  def aug_test(*aug_loop_vars):
    # Restore composite state before evaluating the user condition.
    set_state(aug_loop_vars[num_basic:])
    return test(*aug_loop_vars[:num_basic])

  def aug_body(*aug_loop_vars):
    """Runs the user body and re-captures composite state."""
    set_state(aug_loop_vars[num_basic:])
    new_basic_vars = body(*aug_loop_vars[:num_basic])
    new_aug_vars = new_basic_vars + get_state()
    _verify_tf_loop_vars(aug_init_vars, aug_loop_vars, new_aug_vars,
                         basic_symbol_names + composite_symbol_names, opts)

    return new_aug_vars

  # Non-v2 while_loop unpacks the results when there is only one return value;
  # force the same structure across versions.
  opts['return_same_structure'] = True

  if 'shape_invariants' in opts:
    opts['shape_invariants'] = _shape_invariants_mapping_to_positional_list(
        opts['shape_invariants'], aug_init_vars)

  final_aug_vars = control_flow_ops.while_loop(aug_test, aug_body,
                                               aug_init_vars, **opts)
  # Write the final composite state back into the Python environment and
  # return only the basic loop vars to the caller.
  set_state(final_aug_vars[num_basic:])
  return final_aug_vars[:num_basic]
924
925
class _PythonLoopChecker(object):
  """Verifies Python loops for TF-specific limits.

  Tracks the number of iterations of a Python-unrolled loop and, optionally,
  whether the loop body keeps creating TF ops — a sign the user probably
  intended a TF loop instead.
  """

  def __init__(self):
    # Number of completed iterations.
    self.iterations = 0
    # Whether to keep looking for inefficient graph unrolling.
    self.check_inefficient_unroll = WARN_INEFFICIENT_UNROLL

    # Triggered when we decided to test the op counts.
    self.check_op_count_after_iteration = False

    # Snapshot of graph ops taken before an iteration; None while not
    # actively checking. (Previously left uninitialized until
    # before_iteration ran, which made the attribute's existence
    # order-dependent.)
    self.ops_before_iteration = None

  def _get_ops(self):
    """Returns the operations currently in the default graph."""
    return ops.get_default_graph().get_operations()

  def _check_unroll_limits(self):
    """Raises if the loop ran more than the allowed number of iterations."""
    if LIMIT_PYTHON_ITERATIONS and self.iterations > PYTHON_MAX_ITERATIONS:
      raise ValueError('iteration limit exceeded')

  def _stop_checking_inefficient_unroll(self):
    """Permanently disables the inefficient-unroll check."""
    self.check_inefficient_unroll = False
    self.ops_before_iteration = None

  def _verify_inefficient_unroll(self):
    """Checks for possibly-inefficient creation of ops in a Python loop."""
    assert self.ops_before_iteration is not None
    ops_after_iteration = self._get_ops()
    new_ops = tuple(
        op for op in ops_after_iteration if op not in self.ops_before_iteration)

    if len(new_ops) < INEFFICIENT_UNROLL_MIN_OPS:
      return False

    # TODO(mdan): Add location information.
    ag_logging.warn(
        'TensorFlow ops are being created in a Python loop with large number'
        ' of iterations. This can lead to slow startup. Did you mean to use a'
        ' TensorFlow loop? For example, `while True:` is a Python loop, and'
        ' `while tf.constant(True):` is a TensorFlow loop. The following'
        ' ops were created after iteration %s: %s', self.iterations, new_ops)
    return True

  def before_iteration(self):
    """Called before each iteration in a Python loop."""
    if (self.check_inefficient_unroll and
        self.iterations > INEFFICIENT_UNROLL_MIN_ITERATIONS):
      self.ops_before_iteration = self._get_ops()
      self.check_op_count_after_iteration = True

  def after_iteration(self):
    """Called after each iteration in a Python loop."""
    self.iterations += 1

    self._check_unroll_limits()

    if self.check_inefficient_unroll and self.check_op_count_after_iteration:
      # Fixed typo in the method name (was `_verify_ineffcient_unroll`).
      did_warn = self._verify_inefficient_unroll()
      if did_warn:
        self._stop_checking_inefficient_unroll()  # Only warn once.
      elif self.iterations > INEFFICIENT_UNROLL_MIN_ITERATIONS + 3:
        # Once deciding to check the op counts, only do it for a few iterations.
        self._stop_checking_inefficient_unroll()
986
987
def _py_while_stmt(test, body, get_state, set_state, init_vars, opts):
  """Overload of while_stmt that executes a Python while loop."""
  # State capture hooks and staging options are only meaningful for TF loops;
  # a plain Python loop mutates state directly.
  del opts, get_state, set_state

  if __debug__:
    checker = _PythonLoopChecker()

  state = init_vars
  while test(*state):
    if __debug__:
      checker.before_iteration()
    state = body(*state)
    if __debug__:
      checker.after_iteration()

  return state
1007
1008
def if_stmt(cond,
            body,
            orelse,
            get_state,
            set_state,
            basic_symbol_names,
            composite_symbol_names):
  """Functional form of an if statement.

  Args:
    cond: Boolean.
    body: Callable with no arguments, and outputs of the positive (if) branch as
      return type.
    orelse: Callable with no arguments, and outputs of the negative (else)
      branch as return type.
    get_state: Function that returns a tuple containing the values of all
      composite symbols modified within the conditional. This allows access to
      state that branches may mutate through side effects. This function is not
      needed and should not be called when dispatching to code matching Python's
      default semantics. This is useful for checkpointing to avoid unintended
      side-effects when staging requires evaluating all code-paths.
    set_state: Function to set the values of all composite symbols modified
      within the conditional. This is the complement to get_state, used to
      restore checkpointed values. Its single argument is a tuple containing
      values for each composite symbol that may be modified in a branch of the
      conditional. This is usually the result of a call to get_state.
    basic_symbol_names: Tuple containing basic loop var names.
    composite_symbol_names: Tuple containing composite loop var names.

  Returns:
    Tuple containing the statement outputs.
  """
  # Note: tf.cond doesn't support SparseTensor; only dense tensor conditions
  # are staged, everything else takes the Python path.
  if not tensors.is_dense_tensor(cond):
    return _py_if_stmt(cond, body, orelse)
  return tf_if_stmt(cond, body, orelse, get_state, set_state,
                    basic_symbol_names, composite_symbol_names)
1047
1048
def tf_if_stmt(cond, body, orelse, get_state, set_state, basic_symbol_names,
               composite_symbol_names):
  """Overload of if_stmt that stages a TF cond."""
  # Wrap the branches first so that undefined-symbol errors refer to the
  # user's original branch code, then isolate their state mutations.
  body = _wrap_disallow_undefs_from_cond(body, branch_name='if')
  orelse = _wrap_disallow_undefs_from_cond(orelse, branch_name='else')
  body = _isolate_state(body, get_state, set_state)
  orelse = _isolate_state(orelse, get_state, set_state)

  # `state` currently includes the values of any composite symbols (e.g. `a.b`)
  # composites modified by the loop. `final_vars` includes the values of basic
  # symbols (e.g. `a`) which cannot be passed by reference and must be returned.
  # See _isolate_state.
  # TODO(mdan): We should minimize calls to get/set_state.

  # `result` is shared by the two wrappers below: it lets whichever branch is
  # traced second verify its outputs against the branch traced first.
  body_branch = 0
  orelse_branch = 1
  result = [None, None]

  def error_checking_body():
    result[body_branch] = body()
    if result[orelse_branch] is not None:
      # The else branch was already traced; check structural consistency.
      _verify_tf_cond_vars(result[body_branch], result[orelse_branch],
                           basic_symbol_names + composite_symbol_names)
    return result[body_branch]

  def error_checking_orelse():
    result[orelse_branch] = orelse()
    if result[body_branch] is not None:
      # The if branch was already traced; check structural consistency.
      _verify_tf_cond_vars(result[body_branch], result[orelse_branch],
                           basic_symbol_names + composite_symbol_names)
    return result[orelse_branch]

  # Each isolated branch returns a (branch outputs, captured composite state)
  # pair; see _isolate_state.
  final_vars, final_state = control_flow_ops.cond(cond, error_checking_body,
                                                  error_checking_orelse)

  # Write the composite state of the taken branch back into the environment.
  set_state(final_state)

  return final_vars
1087
1088
1089def _isolate_state(func, get_state, set_state):
1090  """Wraps func to (best-effort) isolate state mutations that func may do.
1091
1092  The simplest example of state mutation is mutation of variables (via e.g.
1093  attributes), or modification of globals.
1094
1095  This allows us to more safely execute this function without worrying about
1096  side effects when the function wasn't normally expected to execute. For
1097  example, staging requires that the function is executed ahead of time, and
1098  we need to ensure its effects are not observed during normal execution.
1099
1100  Args:
1101    func: () -> Any
1102    get_state: () -> Any, returns the current state
1103    set_state: (Any) -> None, resets the state to the specified values.
1104      Typically the result of an earlier call to `get_state`.
1105
1106  Returns:
1107    Tuple[Any, Any], where the first element is the return value of `func`,
1108    and the second is the final state values.
1109  """
1110
1111  def wrapper():
1112    init_state = get_state()
1113    new_vars = func()
1114    # TODO(mdan): These should be copies, lest set_state might affect them.
1115    new_state = get_state()
1116    set_state(init_state)
1117    return new_vars, new_state
1118
1119  return wrapper
1120
1121
def _wrap_disallow_undefs_from_cond(func, branch_name):
  """Wraps conditional branch to disallow returning undefined symbols."""

  def wrapper():
    """Calls function and raises an error if undefined symbols are returned."""
    results = func()

    # Normalize single results to a tuple for uniform inspection.
    results_tuple = results if isinstance(results, tuple) else (results,)

    undefined = tuple(
        v for v in results_tuple if isinstance(v, variables.Undefined))
    if undefined:
      raise ValueError(
          'The following symbols must also be initialized in the {} branch: {}.'
          ' Alternatively, you may initialize them before the if'
          ' statement.'.format(branch_name,
                               tuple(s.symbol_name for s in undefined)))

    if any(
        isinstance(r, variables.UndefinedReturnValue) for r in results_tuple):
      raise ValueError(
          'A value must also be returned from the {} branch. If a value is '
          'returned from one branch of a conditional a value must be '
          'returned from all branches.'.format(branch_name))

    return results

  return wrapper
1151
1152
1153def _py_if_stmt(cond, body, orelse):
1154  """Overload of if_stmt that executes a Python if statement."""
1155  return body() if cond else orelse()
1156