# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Control flow statements: loops, conditionals, etc.

Note: most of these operators accept pairs of get_state/set_state functions, to
capture mutations that the corresponding code blocks might make. These
mutations only need to be captured when staging the control flow, and they just
work when reverting to Python behavior.

__Examples__

```
while cond:
  self.x += i
```

When the functionalized version is executed as a Python loop, it just works:

```
def loop_body():
  self.x += i  # works as expected for Python loops
```

But it won't work for TF loops:

```
def loop_body():
  self.x += i  # self.x has the wrong value!
```

get_state/set_state allow piping the mutations through the loop variables as
well, in effect changing the loop body:

```
def loop_body(self_x):
  self.x = self_x  # self.x now has the proper value
  self.x += i  # the original block
  self_x = self.x  # write self.x back into the loop vars
  return self_x

self_x = tf.while_loop(...)
self.x = self_x  # the result is now properly captured
```
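
A minimal, hand-written sketch of the corresponding state accessors (the names
are illustrative; AutoGraph generates equivalents of these automatically):

```
def get_state():
  return (self.x,)

def set_state(loop_vars):
  self.x, = loop_vars
```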
"""

import functools
import sys
import traceback

import numpy as np

from tensorflow.python.autograph.operators import py_builtins
from tensorflow.python.autograph.operators import variables
from tensorflow.python.autograph.utils import ag_logging
from tensorflow.python.autograph.utils import misc
from tensorflow.python.autograph.utils import tensors
from tensorflow.python.data.experimental.ops import take_while_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.types import distribute
from tensorflow.python.util import nest
from tensorflow.python.util import variable_utils


PYTHON_MAX_ITERATIONS = 100000000  # Fails in about one minute for empty loops.
WARN_INEFFICIENT_UNROLL = True
INEFFICIENT_UNROLL_MIN_ITERATIONS = 50000
INEFFICIENT_UNROLL_MIN_OPS = 1


# TODO(mdan): Use the custom operator pattern instead of type dispatch.
# An example of this pattern is found in the implementation of distributed
# datasets. Before it can be used, though, we need to standardize the
# interface.


def _is_none_or_undef(value):
  """Tests whether a value is None or undefined.

  AutoGraph represents undefined symbols using special objects of type
  Undefined or UndefinedReturnValue.

  Args:
    value: value to test

  Returns:
    Boolean
  """
  return ((value is None)
          or isinstance(value, variables.UndefinedReturnValue)
          or isinstance(value, variables.Undefined))


def _verify_tf_condition(cond, tag):
  """Ensures that the condition can be used in a TF control flow."""
  extra_hint = 'to check for None, use `is not None`'
  cond = ops.convert_to_tensor_v2(cond)

  if cond.dtype != dtypes.bool:
    raise ValueError(
        'condition of {} expected to be `tf.bool` scalar, got {}'
        '; to use as boolean Tensor, use `tf.cast`'
        '; {}'.format(tag, cond, extra_hint))

  if cond.shape is None or cond.shape.ndims is None:
    # TODO(mdan): Consider an explicit size check, if not too slow.
    cond = array_ops.reshape(cond, ())

  elif cond.shape.ndims > 0:
    known_dims = [d for d in cond.shape.as_list() if d is not None]
    if np.prod(known_dims) > 1:
      raise ValueError(
          'condition of {} expected to be `tf.bool` scalar, got {}'
          '; {}'.format(tag, cond, extra_hint))
    else:
      cond = array_ops.reshape(cond, ())

  return cond


def _verify_loop_init_vars(init_vars,
                           symbol_names,
                           first_iter_vars=None,
                           extra_message=None):
  """Ensures that all values in the state are valid to use in a TF loop.

  The init_vars may contain placeholder values derived from first_iter_vars.

  Args:
    init_vars: initial loop variables (as taken before entering the loop)
    symbol_names: corresponding names of the initial loop variables
    first_iter_vars: loop variables after one iteration of the loop
    extra_message: an extra string to append to the error message, in case of
      "undefined variable" errors (see variables.Undefined)
  """
  if not symbol_names:
    return
  if first_iter_vars is None:
    first_iter_vars = (None,) * len(symbol_names)

  assert len(symbol_names) == len(init_vars)
  assert len(symbol_names) == len(first_iter_vars)
  for name, val, fi_val in zip(symbol_names, init_vars, first_iter_vars):
    if isinstance(val, variables.UndefinedReturnValue):
      if fi_val:
        raise ValueError(
            'the return value from a TensorFlow loop may only be a {}; got {}'
            .format(LEGAL_LOOP_TYPES, type(fi_val)))
      else:
        # TODO(mdan): This can be handled by removing the return value.
        raise NotImplementedError(
            'a return statement cannot be placed inside this TensorFlow loop;'
            ' this may happen if a return statement depends on a'
            ' static Python condition such as a hyperparameter')

    error_msg = None
    if val is None:
      error_msg = "'{}' may not be None before the loop".format(name)
    elif isinstance(val, variables.Undefined):
      error_msg = "'{}' must be defined before the loop".format(name)
      if extra_message:
        error_msg += '\n' + extra_message

    if error_msg is not None:
      raise ValueError(error_msg)


def _is_subshape(left, right):
  """Returns True if left shape is at least as specific as right shape."""
  # TODO(mdan): This code should be in TensorShape.
  # Note: this is not the same as TensorShape.is_compatible_with, which is
  # symmetric.
  # This code also duplicates _ShapeLessThanOrEqual from control_flow_ops.py.
  if right.dims is None:
    return True
  if left.ndims != right.ndims:
    return False
  for ldim, rdim in zip(left.dims, right.dims):
    if rdim.value is not None and ldim.value != rdim.value:
      return False
  return True
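

# For intuition, hypothetical illustrations of _is_subshape (values invented
# for this example):
#
#   _is_subshape(tf.TensorShape([3, 4]), tf.TensorShape([None, 4]))  # True
#   _is_subshape(tf.TensorShape([None, 4]), tf.TensorShape([3, 4]))  # False
#   _is_subshape(tf.TensorShape([3, 4]), tf.TensorShape(None))       # True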


# TODO(mdan): Remove these verifications once TF ops can properly report names.
def _verify_single_loop_var(
    name, check_shape, init, entry, exit_, shape_invariant):
  """Verifies whether the initial, entry and exit values are consistent."""
  assert entry is not None, "no TF op should set '{}' to None?".format(name)
  if exit_ is None:
    raise ValueError("'{}' is None at the end of the iteration.".format(name))

  if isinstance(init, (bool, int, float, str, np.ndarray)):
    init = ops.convert_to_tensor_v2(init)
  if isinstance(entry, (bool, int, float, str, np.ndarray)):
    entry = ops.convert_to_tensor_v2(entry)
  if isinstance(exit_, (bool, int, float, str, np.ndarray)):
    exit_ = ops.convert_to_tensor_v2(exit_)

  if (not tensor_util.is_tf_type(entry) or
      not tensor_util.is_tf_type(exit_)):
    return

  # TODO(mdan): Properly account for CompositeTensors.
  if (not hasattr(entry, 'dtype') or
      not hasattr(exit_, 'dtype')):
    return
  if (not hasattr(entry, 'shape') or
      not hasattr(exit_, 'shape')):
    return

  if entry.dtype != exit_.dtype:
    raise TypeError(
        "'{}' has dtype {} before the loop, but dtype {} after one"
        ' iteration'.format(
            name,
            entry.dtype.name,
            exit_.dtype.name,
        ))
  if check_shape:
    exit_shape = exit_.shape
    if shape_invariant is None:
      entry_shape = entry.shape
      if not _is_subshape(exit_shape, entry_shape):
        raise ValueError(
            "'{}' has shape {} before the loop, but shape {} after one"
            ' iteration. Use tf.autograph.experimental.set_loop_options to set'
            ' shape invariants.'.format(name, entry_shape, exit_shape))
    else:
      init_shape = init.shape
      if not _is_subshape(init_shape, shape_invariant):
        raise ValueError(
            "'{}' has shape {} before the loop, which does not conform with"
            ' the shape invariant {}.'.format(name, init_shape,
                                              shape_invariant))
      if not _is_subshape(exit_shape, shape_invariant):
        raise ValueError(
            "'{}' has shape {} after one iteration, which does not conform"
            ' with the shape invariant {}.'.format(
                name, exit_shape, shape_invariant))
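

# Illustrative example of the dtype consistency this enforces (hypothetical
# user code inside a staged TF loop):
#
#   x = tf.constant(0)      # int32 on entry
#   while tf.constant(True):
#     x = tf.constant(1.0)  # float32 on exit -> TypeError after one iteration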


def _verify_tf_loop_vars(init_vars,
                         iter_entry_vars,
                         iter_exit_vars,
                         symbol_names,
                         opts,
                         check_shapes=True):
  """Verifies loop variables for consistency."""
  if check_shapes and 'shape_invariants' in opts:
    shape_invariants = opts['shape_invariants']
  else:
    shape_invariants = nest.map_structure(lambda _: None, iter_entry_vars)

  assert len(symbol_names) == len(shape_invariants)
  assert len(symbol_names) == len(init_vars)
  assert len(symbol_names) == len(iter_entry_vars)
  assert len(symbol_names) == len(iter_exit_vars)

  for i in range(len(symbol_names)):
    name = symbol_names[i]
    init = init_vars[i]
    entry = iter_entry_vars[i]
    exit_ = iter_exit_vars[i]
    invariant = shape_invariants[i]

    try:
      nest.assert_same_structure(init, entry, expand_composites=True)
    except (ValueError, TypeError):
      # `Variable`s in `init` may be implicitly converted to `Tensor`s. Convert
      # `ResourceVariable`s to Tensors so tf.nest.assert_same_structure
      # won't break due to type spec mismatches between `ResourceVariable`s and
      # `Tensor`s.
      try:
        init_tensors = variable_utils.convert_variables_to_tensors(init)
        nest.assert_same_structure(init_tensors, entry, expand_composites=True)
      except (ValueError, TypeError) as e:
        raise TypeError("'{}' does not have the same nested structure after"
                        ' one iteration.\n\n{}'.format(name, e)) from e

    try:
      nest.assert_same_structure(entry, exit_, expand_composites=True)
    except (ValueError, TypeError) as e:
      raise TypeError("'{}' does not have the same nested structure after one"
                      ' iteration.\n\n{}'.format(name, e)) from e
    if invariant is not None:
      try:
        nest.assert_same_structure(init, invariant, expand_composites=False)
      except (ValueError, TypeError) as e:
        raise TypeError("'{}' does not have the same nested structure as its"
                        ' corresponding shape invariant.\n\n{}'.format(
                            name, e)) from e

    nest.map_structure(
        functools.partial(_verify_single_loop_var, name, check_shapes), init,
        entry, exit_, invariant)
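

# When a loop variable legitimately changes shape across iterations, the check
# can be relaxed with an explicit invariant, e.g. (hypothetical user code,
# mirroring the tf.autograph.experimental.set_loop_options documentation):
#
#   x = tf.constant([1])
#   while tf.reduce_sum(x) < 100:
#     tf.autograph.experimental.set_loop_options(
#         shape_invariants=[(x, tf.TensorShape([None]))])
#     x = tf.concat([x, x], axis=0)  # x grows each iteration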


def verify_single_cond_var(name, body_var, orelse_var):
  """Verifies whether body_var and orelse_var are consistent."""
  if body_var is None:
    raise ValueError("'{}' is None at the end of the main branch.".format(name))
  if orelse_var is None:
    raise ValueError(
        "'{}' is None at the end of the else branch.".format(name))

  if isinstance(body_var, (bool, int, float, str, np.ndarray)):
    body_var = ops.convert_to_tensor_v2(body_var)

  if isinstance(orelse_var, (bool, int, float, str, np.ndarray)):
    orelse_var = ops.convert_to_tensor_v2(orelse_var)

  if (not tensor_util.is_tf_type(body_var) or
      not tensor_util.is_tf_type(orelse_var)):
    return

  # TODO(mdan): Properly account for CompositeTensors.
  if (not hasattr(body_var, 'dtype') or
      not hasattr(orelse_var, 'dtype')):
    return

  if body_var.dtype != orelse_var.dtype:
    raise TypeError(
        "'{}' has dtype {} in the main branch, but dtype {} in the else"
        ' branch'.format(name, body_var.dtype.name,
                         orelse_var.dtype.name))


def _verify_tf_cond_branch_vars(vars_, symbol_names, branch_name):
  """Verifies variables output by a conditional branch for consistency."""
  for name, var_ in zip(symbol_names, vars_):
    if isinstance(var_, variables.Undefined):
      raise ValueError(
          "'{}' must also be initialized in the {} branch".format(
              name, branch_name))
    if isinstance(var_, variables.UndefinedReturnValue):
      raise ValueError(
          'the {} branch must also have a return statement.'.format(
              branch_name))


def _verify_tf_cond_vars(body_vars, orelse_vars, symbol_names):
  """Verifies variables manipulated by a conditional for consistency."""
  named_vars = zip(symbol_names, body_vars, orelse_vars)

  for name, body_var, orelse_var in named_vars:
    try:
      nest.assert_same_structure(body_var, orelse_var, expand_composites=True)
    except (ValueError, TypeError):
      # One branch of cond could be a `Tensor`, while the other branch could be
      # a `ResourceVariable`. Convert `ResourceVariable`s to `Tensor`s so
      # assert_same_structure won't fail.
      try:
        body_var_tensors = variable_utils.convert_variables_to_tensors(
            body_var)
        orelse_var_tensors = variable_utils.convert_variables_to_tensors(
            orelse_var)
        nest.assert_same_structure(body_var_tensors, orelse_var_tensors,
                                   expand_composites=True)
      except (ValueError, TypeError) as e:
        raise TypeError(
            "'{}' must have the same nested structure in the main and else"
            ' branches:\n\n{}'.format(name, str(e))) from e
    nest.map_structure(
        functools.partial(verify_single_cond_var, name), body_var, orelse_var)
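

# Illustrative example of the branch consistency this enforces (hypothetical
# user code inside a staged TF cond):
#
#   if tf.constant(True):
#     x = tf.constant(1)    # int32 in the main branch
#   else:
#     x = tf.constant(1.0)  # float32 in the else branch -> TypeError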


def for_stmt(iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Functional form of a for statement.

  The loop operates on a state, which includes all symbols that are
  variant across loop iterations, excluding the variables local to the loop.

  For example, given the loop below that calculates the geometric and
  arithmetic means of some numbers:

  ```
  geo_mean = 1
  arith_mean = 0
  for i in range(n):
    a = numbers[i]
    geo_mean *= a
    arith_mean += a
  ```

  The state is represented by the variables geo_mean and arith_mean. The
  `extra_test`, `body`, `get_state` and `set_state` functions must bind to the
  original `geo_mean` and `arith_mean` symbols, using `nonlocal`.

  The inputs and outputs of the callables representing the loop blocks are not
  explicit - instead, these functions must use nonlocal/global for side
  effects. The inputs and outputs are instead controlled by the
  set_state/get_state functions.

  Args:
    iter_: The entity being iterated over.
    extra_test: Callable with boolean return type. An additional loop
      condition.
    body: Callable representing the actual loop body.
    get_state: Additional callable which can capture additional state (such as
      the values of composite symbols). This is only useful when staging the
      loop.
    set_state: Additional callable which saves values captured by get_state
      back into the Python environment. This is only useful when staging the
      loop.
    symbol_names: Tuple containing names of the loop variables returned by
      get_state.
    opts: Optional dict of extra loop parameters.
  """
  if tensor_util.is_tf_type(iter_):
    if tensors.is_range_tensor(iter_):
      _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state,
                         symbol_names, opts)
    elif isinstance(iter_, ragged_tensor.RaggedTensor):
      _tf_ragged_for_stmt(iter_, extra_test, body, get_state, set_state,
                          symbol_names, opts)
    else:
      _known_len_tf_for_stmt(
          iter_, extra_test, body, get_state, set_state, symbol_names, opts)

  elif isinstance(iter_, dataset_ops.DatasetV2):
    _tf_dataset_for_stmt(
        iter_, extra_test, body, get_state, set_state, symbol_names, opts)

  elif isinstance(iter_, iterator_ops.OwnedIterator):
    _tf_iterator_for_stmt(
        iter_, extra_test, body, get_state, set_state, symbol_names, opts)

  elif isinstance(iter_, ragged_tensor.RaggedTensor):
    _tf_ragged_for_stmt(
        iter_, extra_test, body, get_state, set_state, symbol_names, opts)

  elif isinstance(iter_, distribute.Iterator):
    _tf_iterator_for_stmt(
        iter_, extra_test, body, get_state, set_state, symbol_names, opts)

  elif isinstance(iter_, distribute.Iterable):
    # TODO(b/162250181): Use _tf_iterator_for_stmt(iter(iter_)...
    _tf_distributed_iterable_for_stmt(
        iter_, extra_test, body, get_state, set_state, symbol_names, opts)

  else:
    _py_for_stmt(iter_, extra_test, body, None, None)
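

# For illustration, inside an enclosing function, a loop such as
# `for i in tf.range(3): s += i` is roughly lowered to the following
# (a hand-written sketch of what AutoGraph generates; names are illustrative):
#
#   s = tf.constant(0)
#
#   def get_state():
#     return (s,)
#
#   def set_state(loop_vars):
#     nonlocal s
#     s, = loop_vars
#
#   def body(i):
#     nonlocal s
#     s += i
#
#   for_stmt(tf.range(3), None, body, get_state, set_state, ('s',),
#            {'iterate_names': 'i'})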


def _py_for_stmt(iter_, extra_test, body, get_state, set_state):
  """Overload of for_stmt that executes a Python for loop."""
  del get_state, set_state

  if __debug__:
    checker = _PythonLoopChecker()
    before_iteration = checker.before_iteration
    after_iteration = checker.after_iteration
    before_iteration()

    original_body = body

    def protected_body(protected_iter):
      original_body(protected_iter)
      after_iteration()
      before_iteration()

    body = protected_body

  if extra_test is not None:

    def guarded_extra_test():
      extra_test_result = extra_test()
      try:
        # Note: Using try/except and not tensor_util.is_tf_type to avoid
        # performance degradation.
        return bool(extra_test_result)
      except errors_impl.OperatorNotAllowedInGraphError as e:
        ag_logging.log(
            1,
            'Caught error while evaluating loop stop condition',
            exc_info=True)
        # TODO(mdan): We can pass the location of extra_test and show it here.
        raise NotImplementedError(
            'break and return statements which depend on a TF condition are'
            ' not supported in Python for loops. Did you intend to make it a'
            ' TF loop?\nSee '
            'https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
            'python/autograph/g3doc/reference/limitations.md'
            '#consistency-of-control-flow-types for more info.') from e

    if guarded_extra_test():
      for target in iter_:
        body(target)
        if not guarded_extra_test():
          break

  else:
    for target in iter_:
      body(target)


def _add_max_iterations_hint(opts, n):
  # TODO(b/159186914): Remove the safeguard, and always set maximum_iterations.
  if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
    opts['maximum_iterations'] = n


def _known_len_tf_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over TF entities that admit a length."""
  n = py_builtins.len_(iter_)

  # TODO(b/117628877): Revisit performance once XLA has the necessary support.
  # Note: using a TensorArray creates an extra copy, but can calculate
  # gradients more efficiently than StridedSlice.
  ta = tensor_array_ops.TensorArray(iter_.dtype, size=n)
  iter_ = ta.unstack(iter_)

  iterate_index = 0

  def aug_get_state():
    return (iterate_index,) + get_state()

  def aug_set_state(aug_loop_vars):
    nonlocal iterate_index
    # TODO(b/171479293): Drop the lint override.
    iterate_index, *loop_vars = aug_loop_vars  # pylint:disable=unused-variable
    # The iteration index is not "output" by the for loop. If the iterate
    # is used outside the loop, it will appear in the loop vars separately.
    set_state(loop_vars)

  def aug_body():
    nonlocal iterate_index
    body(iter_.read(iterate_index))
    iterate_index += 1

  def aug_test():
    main_test = iterate_index < n
    if extra_test is not None:
      return control_flow_ops.cond(main_test, extra_test, lambda: False)
    return main_test

  _add_max_iterations_hint(opts, n)

  _tf_while_stmt(
      aug_test,
      aug_body,
      aug_get_state,
      aug_set_state,
      ('<internal iterate>',) + symbol_names,
      opts,
  )


def _tf_ragged_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over TF ragged tensors."""
  init_vars = get_state()
  _verify_loop_init_vars(init_vars, symbol_names)

  # TODO(mdan): Move this into len()? Requires eager support.
  if iter_.shape and iter_.shape[0] is not None:
    n = iter_.shape[0]
  else:
    n = iter_.row_lengths()[0]

  iterate_index = 0

  def aug_get_state():
    return (iterate_index,) + get_state()

  def aug_set_state(aug_loop_vars):
    nonlocal iterate_index
    # TODO(b/171479293): Drop the lint override.
    iterate_index, *loop_vars = aug_loop_vars  # pylint:disable=unused-variable
    # The iteration index is not "output" by the for loop. If the iterate
    # is used outside the loop, it will appear in the loop vars separately.
    set_state(loop_vars)

  def aug_body():
    nonlocal iterate_index
    body(iter_[iterate_index])
    iterate_index += 1

  def aug_test():
    main_test = iterate_index < n
    if extra_test is not None:
      return control_flow_ops.cond(main_test, extra_test, lambda: False)
    return main_test

  _add_max_iterations_hint(opts, n)

  _tf_while_stmt(
      aug_test,
      aug_body,
      aug_get_state,
      aug_set_state,
      ('<internal iterate>',) + symbol_names,
      opts)


def _tf_range_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over a TF range (and elides it)."""
  start, limit, delta = iter_.op.inputs

  iterate = start

  def _value_or(name, var, default):
    if (name == opts['iterate_names']
        and isinstance(var, variables.Undefined)):
      return default
    return var

  def aug_get_state():
    state_vars = get_state()
    state_vars = tuple(
        _value_or(name, var, iterate)
        for name, var in zip(symbol_names, state_vars))
    return (iterate,) + state_vars

  def aug_set_state(aug_loop_vars):
    nonlocal iterate
    # TODO(b/171479293): Drop the lint override.
    iterate, *loop_vars = aug_loop_vars  # pylint:disable=unused-variable
    # The iteration index is not "output" by the for loop. If the iterate
    # is used outside the loop, it will appear in the loop vars separately.
    set_state(loop_vars)

  def aug_body():
    nonlocal iterate
    body(iterate)
    iterate += delta

  def aug_test():
    # TODO(b/159713842): Remove once constant folding works.
    const_delta = tensor_util.constant_value(delta)
    if const_delta is not None:
      if const_delta >= 0:
        main_test = iterate < limit
      else:
        main_test = iterate > limit
    else:
      main_test = math_ops.logical_or(
          math_ops.logical_and(delta >= 0, iterate < limit),
          math_ops.logical_and(delta < 0, iterate > limit))

    if extra_test is not None:
      main_test = control_flow_ops.cond(main_test, extra_test, lambda: False)
    return main_test

  _add_max_iterations_hint(
      opts,
      math_ops.cast(misc.get_range_len(start, limit, delta), dtypes.int32))

  _tf_while_stmt(
      aug_test,
      aug_body,
      aug_get_state,
      aug_set_state,
      ('<internal iterate>',) + symbol_names,
      opts)
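

# For illustration (hypothetical values): `for i in tf.range(10, 0, -2)` elides
# the range tensor entirely; the staged loop is roughly
#
#   i = tf.constant(10)  # start
#   while i > 0:         # delta < 0, so the test is `iterate > limit`
#     body(i)
#     i += -2            # delta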


def _tf_iterator_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over TF Iterators. See for_stmt."""
  symbol_names = ('<internal has_next>',) + symbol_names
  has_next = True

  def aug_get_state():
    return (has_next,) + get_state()

  def aug_set_state(aug_loop_vars):
    nonlocal has_next
    # TODO(b/171479293): Drop the lint override.
    has_next, *loop_vars = aug_loop_vars  # pylint:disable=unused-variable
    set_state(loop_vars)

  init_vars = aug_get_state()
  _verify_loop_init_vars(init_vars, symbol_names)

  def aug_body():
    """Main body passed to _tf_while_stmt."""
    nonlocal has_next
    opt_iterate = iter_.get_next_as_optional()
    has_next = opt_iterate.has_value()
    loop_vars = aug_get_state()  # updated by set_state() in _tf_while_stmt.

    def main_path():
      body(opt_iterate.get_value())
      new_loop_vars = aug_get_state()
      # Note: this verification duplicates the one performed in _tf_while_stmt,
      # but needs to be done earlier to prevent the tf.cond from blowing up
      # first.
      _verify_tf_loop_vars(
          init_vars, loop_vars, new_loop_vars, symbol_names, opts)
      return new_loop_vars

    def noop_path():
      return loop_vars

    # TODO(mdan): If tf.while_loop supported Optional, this could be avoided.
    # Calling set_state so that get_state() inside _tf_while_stmt sees the
    # conditional tensors.
    aug_set_state(
        control_flow_ops.cond(has_next, main_path, noop_path))

  def aug_test():
    # This value takes a complicated path to get here:
    #   prev_iteration_body -> get_state -> tf.while_loop (as loop var)
    #   -> current_iteration_body -> set_state -> has_next
    main_test = has_next
    if extra_test is not None:
      return control_flow_ops.cond(main_test, extra_test, lambda: False)
    return main_test

  _tf_while_stmt(
      aug_test,
      aug_body,
      aug_get_state,
      aug_set_state,
      symbol_names,
      opts)


def _general_purpose_scan(ds, init_state, body):
  """Variant of Dataset.scan with semantics of general-purpose computation."""
  # Datasets are typically intended for data preprocessing. However, in
  # autograph loops they usually appear as general-purpose computations (for
  # example, a custom training loop). These two use cases require significantly
  # different optimization policies, the most important of which is the device
  # placement. The flag override for use_default_device below instructs the
  # runtime to treat the computation as general-purpose, rather than data
  # preprocessing.
  # TODO(mdan): s/use_default_device/specialize_for_input_pipeline.
  # TODO(mdan): Don't use private symbols.
  # pylint:disable=protected-access
  return dataset_ops._ScanDataset(
      ds, init_state, body, use_default_device=False)
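

# For illustration, a dataset loop such as `for x in ds: s += x` with a break
# condition is staged as a transposed pipeline, roughly (hand-written sketch;
# `scan_body`, `predicate` and `reduce_body` stand for the helpers built by
# _tf_dataset_for_stmt below):
#
#   ds = _general_purpose_scan(ds, (s,), scan_body)      # runs the loop body
#   ds = ds.apply(take_while_ops.take_while(predicate))  # applies extra_test
#   (s,) = ds.reduce((s,), reduce_body)                  # extracts final state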


def _tf_dataset_for_stmt(
    ds, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt for TF Datasets, with early stopping. See for_stmt."""
  # Note: This is easier to follow with the insight that the computations in
  # a dataset pipeline are transposed (aka fused).
  # For example, given a pipeline input -> scan -> take_while -> reduce,
  # and a dataset with input [1, 2, 3], the computations occur in the following
  # order:
  #   reduce(take_while(scan(1)))
  #   reduce(take_while(scan(2)))
  #   reduce(take_while(scan(3)))

  init_vars = get_state()
  _verify_loop_init_vars(init_vars, symbol_names)

  # Workaround for Dataset.reduce not allowing empty state tensors - create
  # a dummy state variable that remains unused.
  # TODO(mdan): reduce should allow and match empty structures.
  if not init_vars:
    init_vars = (constant_op.constant(0),)
    symbol_names = ('<internal dummy>',)

    def dummy_set_state(unused_dummy):
      pass

    def dummy_get_state():
      return (constant_op.constant(0),)

    get_state, set_state = dummy_get_state, dummy_set_state

  def scan_body(scan_state, scan_inputs):
    """Main body of the Dataset.scan."""
    loop_vars, iterate = scan_state, scan_inputs
    set_state(loop_vars)

    def main_path():
      body(iterate)
      new_loop_vars = get_state()
      _verify_tf_loop_vars(
          init_vars, loop_vars, new_loop_vars, symbol_names, opts,
          check_shapes=False)
      return new_loop_vars

    if extra_test is not None:
      extra_cond = extra_test()
      new_loop_vars = control_flow_ops.cond(
          extra_cond, main_path, lambda: loop_vars)
    else:
      # TODO(mdan): the optimizer should be able to remove an invariant cond?
      extra_cond = (constant_op.constant(True),)  # dummy value, unused
      new_loop_vars = main_path()

    scan_outputs = new_loop_vars, extra_cond
    new_scan_state = new_loop_vars
    return new_scan_state, scan_outputs

  def take_while_predicate(unused_loop_vars, extra_cond):
    return extra_cond

  def reduce_body(unused_reduce_state, scan_outputs):
    output_loop_vars, unused_extra_cond = scan_outputs
    new_reduce_state = output_loop_vars
    return new_reduce_state

  ds = _general_purpose_scan(ds, init_vars, scan_body)
  if extra_test is not None:
    ds = ds.apply(take_while_ops.take_while(take_while_predicate))
  final_loop_vars = ds.reduce(init_vars, reduce_body)
  set_state(final_loop_vars)


def _tf_distributed_iterable_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over TF distributed datasets."""

  if extra_test is not None:
    raise NotImplementedError(
        'break and return statements are not yet supported in '
        'for ... in distributed input loops.')

  init_vars = get_state()
  _verify_loop_init_vars(init_vars, symbol_names)

  if 'shape_invariants' in opts:
    opts['shape_invariants'] = _shape_invariants_mapping_to_positional_list(
        opts['shape_invariants'], init_vars)

  def reduce_body(loop_vars, iterate):
    set_state(loop_vars)
    body(iterate)
    new_loop_vars = get_state()
    _verify_tf_loop_vars(
        init_vars, loop_vars, new_loop_vars, symbol_names, opts)
    return new_loop_vars

  set_state(iter_.reduce(init_vars, reduce_body))
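

# For illustration, inside an enclosing function, `while s < 10: s += 1` is
# roughly lowered to the following (hand-written sketch of generated code,
# with get_state/set_state defined as in the earlier for_stmt sketch):
#
#   def test():
#     return s < 10
#
#   def body():
#     nonlocal s
#     s += 1
#
#   while_stmt(test, body, get_state, set_state, ('s',), {})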


def while_stmt(test, body, get_state, set_state, symbol_names, opts):
  """Functional form of a while statement.

  The loop operates on a so-called state, which includes all symbols that are
  variant across loop iterations. In what follows we refer to state as either
  a tuple of entities that represent an actual state, or a list of arguments
  of the corresponding types.

  The inputs and outputs of the callables representing the loop blocks are not
  explicit - instead, these functions must use nonlocal/global for side
  effects. The inputs and outputs are instead controlled by the
  set_state/get_state functions.

  Args:
    test: Callable with boolean return type. The loop condition.
    body: Callable representing the actual loop body.
    get_state: Additional callable which can capture additional state (such as
      the values of composite symbols). This is only useful when staging the
      loop.
    set_state: Additional callable which saves values captured by get_state
      back into the Python environment. This is only useful when staging the
      loop.
    symbol_names: Tuple containing the names of all loop variables.
    opts: Optional dict of extra loop parameters.
  """

  # Evaluate the initial test once in order to do the dispatch. The evaluation
  # is isolated to minimize unwanted side effects.
  # TODO(mdan): Do a full iteration - some state types might lower to Tensor.
  with func_graph.FuncGraph('tmp').as_default():
    init_test = test()

  # TensorFlow: Multiple evaluations are acceptable in this case, so we're fine
  # with the re-evaluation of `test` that `_tf_while_stmt` will make.
  if tensors.is_dense_tensor(init_test):
    _tf_while_stmt(test, body, get_state, set_state, symbol_names, opts)
    return

  # Normal Python: We already consumed one evaluation of `test`; consistently,
  # unroll one iteration before dispatching to a normal loop.
  # TODO(mdan): Push the "init_test" value via opts into _py_while_stmt?
  if not init_test:
    return
  body()

  _py_while_stmt(test, body, get_state, set_state, opts)


class _PythonLoopChecker(object):
  """Verifies Python loops for TF-specific limits."""

  __slots__ = (
      'iterations',
      'check_inefficient_unroll',
      'check_op_count_after_iteration',
      'ops_before_iteration',
  )

  def __init__(self):
    self.iterations = 1
    self.check_inefficient_unroll = WARN_INEFFICIENT_UNROLL

    # Triggered when we decide to test the op counts.
    self.check_op_count_after_iteration = False

  def _get_ops(self):
    return ops.get_default_graph().get_operations()

  def _check_unroll_limits(self):
    if self.iterations > PYTHON_MAX_ITERATIONS:
      raise ValueError('iteration limit exceeded')

  def _stop_checking_inefficient_unroll(self):
    self.check_inefficient_unroll = False
    self.check_op_count_after_iteration = False
    self.ops_before_iteration = None

  def _verify_inefficient_unroll(self):
    """Checks for possibly-inefficient creation of ops in a Python loop."""
    assert self.ops_before_iteration is not None
    ops_after_iteration = self._get_ops()
    new_ops = tuple(
        op for op in ops_after_iteration
        if op not in self.ops_before_iteration)

    if len(new_ops) < INEFFICIENT_UNROLL_MIN_OPS:
      return False

    ag_logging.warning(
        'Large unrolled loop detected. Did you mean to use a TF loop?'
        ' The following ops were created after iteration %s: %s'
        '\nSee'
        ' https://github.com/tensorflow/tensorflow/blob/master/'
        'tensorflow/python/autograph/g3doc/reference/common_errors.md'
        '#warning-large-unrolled-loop-detected'
        '\n'
        'Location:'
        '\n%s'
        '', self.iterations, new_ops, '\n'.join(traceback.format_stack()))
    return True

  def before_iteration(self):
    """Called before each iteration in a Python loop."""
    if (self.check_inefficient_unroll and
        self.iterations > INEFFICIENT_UNROLL_MIN_ITERATIONS):
      self.ops_before_iteration = self._get_ops()
      self.check_op_count_after_iteration = True

  def after_iteration(self):
    """Called after each iteration in a Python loop."""
    self.iterations += 1

    self._check_unroll_limits()

    if self.check_op_count_after_iteration:
      did_warn = self._verify_inefficient_unroll()
      if did_warn:
        self._stop_checking_inefficient_unroll()  # Only warn once.
      elif self.iterations > INEFFICIENT_UNROLL_MIN_ITERATIONS + 3:
        # Once deciding to check the op counts, only do it for a few
        # iterations.
        self._stop_checking_inefficient_unroll()


def _py_while_stmt(test, body, get_state, set_state, opts):
  """Overload of while_stmt that executes a Python while loop."""
  del opts, get_state, set_state

  if __debug__:
    checker = _PythonLoopChecker()
    before_iteration = checker.before_iteration
    after_iteration = checker.after_iteration
    before_iteration()

    original_body = body

    def protected_body():
      original_body()
      after_iteration()
      before_iteration()

    body = protected_body

  def guarded_test():
    test_result = test()
    try:
      # Note: Using try/except and not tensor_util.is_tf_type to avoid
      # performance degradation.
      return bool(test_result)
    except errors_impl.OperatorNotAllowedInGraphError as e:
      ag_logging.log(
          1,
          'Caught error while evaluating while loop condition',
          exc_info=True)
      # TODO(mdan): distinguish between these two cases.
      raise NotImplementedError(
          'The condition of a while loop started as non-Tensor, then changed'
          ' to Tensor. This may happen either because variables changed type,'
          ' or when a break or return statement inside the loop depends on a'
          ' Tensor condition. In both cases, changing to a TF loop should'
          ' remove the error.\nSee '
          'https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
          'python/autograph/g3doc/reference/limitations.md'
          '#consistency-of-control-flow-types for more info.') from e

  while guarded_test():
    body()


def _shape_invariants_mapping_to_positional_list(mapping, keys):
  # The keys are not expected to be hashable.
  mapping = {id(k): (k, v) for k, v in mapping}
  result = []
  for k in keys:
    map_key, map_val = mapping.get(id(k), (None, None))
    result.append(
        map_val if map_key is k else nest.map_structure(lambda _: None, k))
  return tuple(result)


# Textual description of what a legal TF loop variable is. This description
# summarizes types that _placeholder_value below can handle. Keep the two
# together and in sync.
LEGAL_LOOP_TYPES = 'Tensor, int, float, bool or a list, tuple or dict thereof'


def _placeholder_value(like, shape_invariant, original=None):
  """Constructs a (dummy) placeholder value for a loop-initialized variable.

  Args:
    like: Any object. The value created by the first iteration of the loop. If
      a Python scalar, the placeholder will be the zero value of that type. If
      a Tensor, the placeholder will be a zero tensor of matching shape and
      dtype. If a list, dict or tuple, the placeholder will be an identical
      structure of placeholders.
    shape_invariant: The shape invariant specified by the user (or None, if
      nothing was specified) for the respective variable.
    original: Any object. The value of the variable prior to entering the loop.
      Typically, this is one of the special "Undefined" values, because that's
      when a placeholder is needed.

  Returns:
    Either a zero value of structure, shape and dtype matching 'like', or
    'original', if no such zero value could be created.
  """
  if like is None:
    return original, None

  elif isinstance(like, (variables.Undefined, variables.UndefinedReturnValue)):
    return original, None

  elif isinstance(like, (int, float, bool)):
    return type(like)(0), None

  elif tensor_util.is_tf_type(like):

    like_shape = shape_invariant if shape_invariant is not None else like.shape
    if like_shape is None or like_shape.rank is None:
      return array_ops.zeros((), like.dtype), like_shape

    # If the shape contains dynamic values, set the corresponding starting
    # dimension to either zero or what the shape invariant specified.
    placeholder_shape = []
    has_dynamic_dims = False
    for s, i in zip(like.shape, like_shape):
      if i is None:
        like_dim = 0
      elif isinstance(i, tensor_shape.Dimension):
        if i.value is None:
          like_dim = 0
        else:
          like_dim = i.value
      else:
        like_dim = i

      if s is None:
        placeholder_shape.append(like_dim)
        has_dynamic_dims = True
      elif isinstance(s, tensor_shape.Dimension):
        if s.value is None:
          placeholder_shape.append(like_dim)
          has_dynamic_dims = True
        else:
          placeholder_shape.append(s.value)
      else:
        placeholder_shape.append(s)

    if has_dynamic_dims:
      invariant = like_shape
    else:
      invariant = None

    return array_ops.zeros(placeholder_shape, like.dtype), invariant

  elif isinstance(like, (list, tuple, dict)):
    if shape_invariant is None:
      zipped = nest.map_structure(lambda v: _placeholder_value(v, None),
                                  nest.flatten(like))
    else:
      zipped = nest.map_structure(_placeholder_value, nest.flatten(like),
                                  nest.flatten(shape_invariant))
    vals, invars = zip(*zipped)
    return (nest.pack_sequence_as(like, vals),
            nest.pack_sequence_as(like, invars))

  # This is to be caught by _try_handling_undefineds, to give more context.
  raise TypeError(
      "Found an unsupported type '{}' while creating placeholder for {}."
      ' Supported types include Tensor, int, float, bool, list, tuple or dict.'
      .format(type(like).__name__, like))
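

# For intuition, hypothetical illustrations of _placeholder_value:
#
#   _placeholder_value(7, None)                 # -> (0, None)
#   _placeholder_value(tf.zeros([2, 3]), None)  # -> (tf.zeros([2, 3]), None)
#
#   # A Tensor with an unknown leading dimension and invariant [None, 3] yields
#   # a zero tensor of shape [0, 3], plus the invariant [None, 3].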


def _try_handling_undefineds(body, get_state, set_state, init_vars, nulls,
                             shape_invariants, symbol_names):
  """Makes a best-effort attempt to substitute undefineds with placeholders.

  Note: this substitution requires two things to happen:
   1. the types of loop variables could be inferred (usually by staging one
      iteration)
   2. these types could be replaced by placeholders (e.g. zero values for
      tensors)

  Args:
    body: a function representing the loop body. See while_stmt.
    get_state: state getter for the loop statement. See while_stmt.
    set_state: state setter for the loop statement. See while_stmt.
    init_vars: loop variables before entering the loop. See while_stmt.
    nulls: list of boolean flags indicating whether the corresponding loop var
      is None or undefined.
    shape_invariants: user-specified shape invariant for each loop variable.
    symbol_names: list of loop variable names. See while_stmt.

  Returns:
    A tuple (success, new_init_vars, extra_shape_invariants):
     * success is a boolean flag indicating whether types could be successfully
       inferred (step 1 above)
     * new_init_vars contains the loop vars, with None or undefined values
       replaced by default values, where possible (step 2 above)
     * extra_shape_invariants contains shape invariants that would be needed by
       while_stmt, for instance if the placeholder values had a shape different
       from the corresponding loop outputs
  """
  state_modified = False
  first_iter_vars = None
  failure_message = None

  try:
    # Stage an iteration of the loop body in a temporary graph.
    with func_graph.FuncGraph('tmp').as_default():
      # This call to set_state helps report nicer error messages when symbols
      # are inconsistently used.
      # Another complication is that non-Tensor values will be autocast to
      # Tensor by while_loop, and their static value lost. So we need to
      # account for that here.
      def autocast_to_tensor(v):
        if isinstance(
            v, (int, float, bool, str, list, tuple, np.ndarray, np.generic)):
          init_val = ops.convert_to_tensor_v2(v)
          return array_ops.placeholder(init_val.dtype, init_val.shape)
        return v
      autocast_init_vars = nest.map_structure(autocast_to_tensor, init_vars)
      set_state(autocast_init_vars)
      state_modified = True

      body()
      first_iter_vars = get_state()

    # Note: the actual placeholder value doesn't matter, because as the
    # staging proved, it will be replaced by an actual value before being
    # read.
    inits_and_invariants = tuple(
        (_placeholder_value(iv, i, v) if n else (v, None))
        for v, n, iv, i in zip(init_vars, nulls, first_iter_vars,
                               shape_invariants))
    init_vars, extra_shape_invariants = zip(*inits_and_invariants)
    success = True

  except (UnboundLocalError, TypeError, ValueError, KeyError):
    ag_logging.log(1, 'Caught error while staging loop body', exc_info=True)
    # Fall back to the old functionality. It will likely result in an input
    # validation failure.
    exc = sys.exc_info()
    failure_message = (
        'Note: AutoGraph tried to define it automatically, but ran into a'
        ' {}: {}'.format(exc[0].__name__, exc[1]))

  finally:
    if state_modified:
      set_state(init_vars)

  # This check runs regardless, in case we captured non-Tensor inputs.
  _verify_loop_init_vars(
      init_vars, symbol_names, first_iter_vars, extra_message=failure_message)

  return success, init_vars, extra_shape_invariants


def _runtime_zero_iterations_errmsg(symbol_names, nulls, init_vars):
  """Creates an error message asking for the loop to iterate at least once."""
  var_names = []
  for sn, n, v in zip(symbol_names, nulls, init_vars):
    if not n:
      continue
    if isinstance(v, variables.UndefinedReturnValue):
      var_names.append('the function return value')
    else:
      var_names.append(sn)
  var_names = ', '.join(var_names)
  return 'loop must iterate at least once to initialize {}'.format(var_names)
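

# For illustration (hypothetical user code): in the TF loop below, `x` is
# undefined before the loop. _try_handling_undefineds stages one iteration to
# infer its type, substitutes a zero placeholder, and _tf_while_stmt then
# asserts at runtime that the loop ran at least once:
#
#   i = tf.constant(0)
#   while i < 3:
#     x = i * 2   # `x` is only defined inside the loop body
#     i += 1
#   tf.print(x)   # using `x` afterwards makes it a loop variable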


def _tf_while_stmt(test, body, get_state, set_state, symbol_names, opts):
  """Overload of while_stmt that stages a TF while_loop."""
  init_vars = get_state()
  orig_init_vars = init_vars

  nulls = tuple(_is_none_or_undef(v) for v in init_vars)
  if any(nulls):
    shape_invars_by_init_vals = {
        id(v): i for v, i in opts.get('shape_invariants', ())
    }
    shape_invariants = tuple(
        shape_invars_by_init_vals.get(id(v), None) for v in orig_init_vars)
    (require_one_iteration, init_vars,
     extra_shape_invariants) = _try_handling_undefineds(body, get_state,
                                                        set_state, init_vars,
                                                        nulls, shape_invariants,
                                                        symbol_names)
  else:
    require_one_iteration = False

  if require_one_iteration:
    merged_shape_invariants = dict(shape_invars_by_init_vals)
    # This has two roles:
    #  1. Shape invariants are remapped from the old init vars to the new ones.
    #  2. Any new shape invariants created by the init vars are kept, but only
    #     if the user didn't already specify some.
    for v, nv, ni in zip(orig_init_vars, init_vars, extra_shape_invariants):
      merged_invariant = merged_shape_invariants.get(id(v), ni)
      if merged_invariant is not None:
        merged_shape_invariants[id(nv)] = merged_invariant
    merged_shape_invariants = tuple((nv, merged_shape_invariants[id(nv)])
                                    for nv in init_vars
                                    if id(nv) in merged_shape_invariants)
    if merged_shape_invariants:
      opts = dict(**opts)
      opts['shape_invariants'] = merged_shape_invariants

  def aug_test(*loop_vars):
    if require_one_iteration:
      loop_vars = loop_vars[1:]

    set_state(loop_vars)
    return _verify_tf_condition(test(), 'while loop')

  def aug_body(*loop_vars):
    if require_one_iteration:
      loop_vars = loop_vars[1:]

    set_state(loop_vars)
    body()
    new_loop_vars = get_state()
    _verify_tf_loop_vars(
        init_vars, loop_vars, new_loop_vars, symbol_names, opts)

    if require_one_iteration:
      new_loop_vars = (True,) + new_loop_vars

    return new_loop_vars

  if 'shape_invariants' in opts:
    opts['shape_invariants'] = _shape_invariants_mapping_to_positional_list(
        opts['shape_invariants'], init_vars)

  while_loop_opts = dict(opts)
  while_loop_opts.pop('iterate_names', None)

  # Non-v2 while_loop unpacks the results when there is only one return value.
  # This enforces consistency across versions.
  while_loop_opts['return_same_structure'] = True

  if require_one_iteration:
    aug_init_vars = (False,) + init_vars
    if 'shape_invariants' in while_loop_opts:
      while_loop_opts['shape_invariants'] = (
          (None,) + while_loop_opts['shape_invariants'])
  else:
    aug_init_vars = init_vars

  final_loop_vars = control_flow_ops.while_loop(
      aug_test, aug_body, aug_init_vars, **while_loop_opts)

  if require_one_iteration:
    with ops.control_dependencies([
        control_flow_ops.Assert(final_loop_vars[0], [
            _runtime_zero_iterations_errmsg(symbol_names, nulls, orig_init_vars)
        ])
    ]):
      final_loop_vars = nest.map_structure(
          lambda v: (array_ops.identity(v)
                     if tensor_util.is_tf_type(v) else v),
          final_loop_vars[1:],
      )

  set_state(final_loop_vars)


def if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts):
  """Functional form of an if statement.

  The conditional operates on a state, which includes all symbols whose values
  are a function of the branch taken.

  For example, given the code below that calculates the abs function:

  ```
  x = 1
  if x < 0:
    x = -x
  ```

  The state is represented by the variable `x`. The `body`, `orelse` and
  `set_state` functions must bind to the original `x` symbol, using `nonlocal`.

  The inputs and outputs of the callables representing the branches are not
  explicit - instead, these functions must use nonlocal/global for side
  effects. The inputs and outputs are instead controlled by the
  set_state/get_state functions.

  Args:
    cond: Boolean.
    body: Callable representing the main block of the conditional.
    orelse: Callable representing the else block of the conditional.
    get_state: Function that returns a tuple containing the values of all
      composite symbols modified within the conditional. This allows access to
      state that branches may mutate through side effects. This function is
      not needed and should not be called when dispatching to code matching
      Python's default semantics. This is useful for checkpointing to avoid
      unintended side-effects when staging requires evaluating all code-paths.
    set_state: Function to set the values of all composite symbols modified
      within the conditional. This is the complement to get_state, used to
      restore checkpointed values. The single argument is a tuple containing
      values for each composite symbol that may be modified in a branch of the
      conditional. This is usually the result of a call to get_state.
    symbol_names: Tuple containing basic symbol names.
    nouts: Number of variables output by the statement. Vars which are not
      outputs will not be passed through staged control flow such as tf.cond.
      This includes variables that are defined before the conditional, but are
      not used after it.
  """
  # Note: tf.cond doesn't support SparseTensor.
  if tensors.is_dense_tensor(cond):
    _tf_if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts)
  else:
    _py_if_stmt(cond, body, orelse)


def _tf_if_stmt(
    cond, body, orelse, get_state, set_state, symbol_names, nouts):
  """Overload of if_stmt that stages a TF cond."""
  cond = _verify_tf_condition(cond, 'if statement')

  if not nouts:
    prev_get_state, prev_set_state = get_state, set_state
    # Control flow V1 wants at least one output.
    get_state = lambda: (0,) + prev_get_state()
    set_state = lambda v: prev_set_state(v[1:])
    symbol_names += ('<unused dummy>',)
    nouts = 1

  init_vars = get_state()

  # TODO(mdan): Use nonlocal once we no longer need to support py2.
  new_body_vars_ = [None]
  new_orelse_vars_ = [None]

  def aug_body():
    set_state(init_vars)
    body()
    new_body_vars = get_state()
    new_body_vars = new_body_vars[:nouts]
    new_body_vars_[0] = new_body_vars
    _verify_tf_cond_branch_vars(new_body_vars, symbol_names, 'main')
    if new_orelse_vars_[0] is not None:
      _verify_tf_cond_vars(new_body_vars, new_orelse_vars_[0], symbol_names)
    return new_body_vars

  def aug_orelse():
    set_state(init_vars)
    orelse()
    new_orelse_vars = get_state()
    new_orelse_vars = new_orelse_vars[:nouts]
    new_orelse_vars_[0] = new_orelse_vars
    _verify_tf_cond_branch_vars(new_orelse_vars, symbol_names, 'else')
    if new_body_vars_[0] is not None:
      _verify_tf_cond_vars(new_body_vars_[0], new_orelse_vars, symbol_names)
    return new_orelse_vars

  final_cond_vars = control_flow_ops.cond(
      cond, aug_body, aug_orelse, strict=True)
  final_cond_vars = final_cond_vars + init_vars[nouts:]

  set_state(final_cond_vars)


def _py_if_stmt(cond, body, orelse):
  """Overload of if_stmt that executes a Python if statement."""
  return body() if cond else orelse()
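

# For illustration, inside an enclosing function, `if x < 0: x = -x` is roughly
# lowered to the following (hand-written sketch of what AutoGraph generates):
#
#   def body():
#     nonlocal x
#     x = -x
#
#   def orelse():
#     pass
#
#   def get_state():
#     return (x,)
#
#   def set_state(vars_):
#     nonlocal x
#     x, = vars_
#
#   if_stmt(x < 0, body, orelse, get_state, set_state, ('x',), 1)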