# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Control flow statements: loops, conditionals, etc.

Python 2 compatibility version. Not maintained.

Note: most of these operators accept pairs of get_state/set_state functions, to
capture mutations that the corresponding code blocks might make. These
mutations only need to be captured when staging the control flow, and they just
work when reverting to Python behavior.

__Examples__

```
while cond:
  self.x += i
```

When the functionalized version is executed as a Python loop, it just works:

```
def loop_body():
  self.x += i  # works as expected for Python loops
```

But it won't work for TF loops:

```
def loop_body():
  self.x += i  # self.x has the wrong value!
```

get_state/set_state allow piping the mutations through the loop variables as
well, in effect changing the loop body:

```
def loop_body(self_x):
  self.x = self_x  # self.x now has the proper value
  self.x += i  # the original block
  self_x = self.x  # write self.x back into the loop vars
  return self_x

self_x = tf.while_loop(...)
self.x = self_x  # the result is now properly captured
```
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

import numpy as np

from tensorflow.python.autograph.operators import py_builtins
from tensorflow.python.autograph.operators import variables
from tensorflow.python.autograph.utils import ag_logging
from tensorflow.python.autograph.utils import misc
from tensorflow.python.autograph.utils import tensors
from tensorflow.python.data.experimental.ops import take_while_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import lazy_loader
from tensorflow.python.util import nest


# TODO(b/145618471): Remove this dependency.
# Lazy import to work around circular dependencies.
input_lib = lazy_loader.LazyLoader(
    'input_lib', globals(),
    'tensorflow.python.distribute.input_lib')

LIMIT_PYTHON_ITERATIONS = True
PYTHON_MAX_ITERATIONS = 100000000  # Fails in about one minute for empty loops.
WARN_INEFFICIENT_UNROLL = True
INEFFICIENT_UNROLL_MIN_ITERATIONS = 3000
INEFFICIENT_UNROLL_MIN_OPS = 1


def _disallow_undefs_into_loop(*values):
  """Ensures that all values in the state are defined when entering a loop."""
  undefined = [v for v in values if isinstance(v, variables.Undefined)]
  if undefined:
    raise ValueError(
        '{} must be defined before the loop.'.format(
            ','.join(s.symbol_name for s in undefined)))
  for value in values:
    if isinstance(value, variables.UndefinedReturnValue):
      # Assumption: the loop will only capture the variable which tracks the
      # return value if the loop contained a return statement.
      # TODO(mdan): This should be checked at the place where return occurs.
      raise ValueError(
          'return statements are not supported within a TensorFlow loop.')


def _is_subshape(left, right):
  """Returns True if left shape is at least as specific as right shape."""
  # TODO(mdan): This code should be in TensorShape.
  # Note: this is not the same as TensorShape.is_compatible_with, which is
  # symmetric.
  # This code also duplicates _ShapeLessThanOrEqual from control_flow_ops.py.
  if right.dims is None:
    return True
  if left.ndims != right.ndims:
    return False
  for ldim, rdim in zip(left.dims, right.dims):
    if rdim.value is not None and ldim.value != rdim.value:
      return False
  return True


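# For illustration (assuming `tensor_shape` from tensorflow.python.framework,
# which is not otherwise used in this module): a fully specified shape is at
# least as specific as a partially specified one, but not vice versa.
#
#   _is_subshape(tensor_shape.TensorShape([3, 4]),
#                tensor_shape.TensorShape([None, 4]))  # True
#   _is_subshape(tensor_shape.TensorShape([None, 4]),
#                tensor_shape.TensorShape([3, 4]))     # False

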
# TODO(mdan): Remove these verifications once TF ops can properly report names.
def _verify_single_loop_var(
    name, check_shape, init, entry, exit_, shape_invariant):
  """Verifies whether the initial, entry and exit values are consistent."""
  if isinstance(init, (bool, int, float, str, np.ndarray)):
    init = ops.convert_to_tensor_v2(init)
  if isinstance(entry, (bool, int, float, str, np.ndarray)):
    entry = ops.convert_to_tensor_v2(entry)
  if isinstance(exit_, (bool, int, float, str, np.ndarray)):
    exit_ = ops.convert_to_tensor_v2(exit_)

  if (not tensor_util.is_tf_type(entry) or
      not tensor_util.is_tf_type(exit_)):
    return

  # TODO(mdan): Properly account for CompositeTensors.
  if (not hasattr(entry, 'dtype') or
      not hasattr(exit_, 'dtype')):
    return
  if (not hasattr(entry, 'shape') or
      not hasattr(exit_, 'shape')):
    return

  if entry.dtype != exit_.dtype:
    raise TypeError(
        '"{}" has dtype {} before the loop, but dtype {} after one'
        ' iteration. TensorFlow control flow requires it to stay the'
        ' same.'.format(
            name,
            entry.dtype.name,
            exit_.dtype.name,
        ))
  if check_shape:
    exit_shape = exit_.shape
    if shape_invariant is None:
      entry_shape = entry.shape
      if not _is_subshape(exit_shape, entry_shape):
        raise ValueError(
            '"{}" has shape {} before the loop, but shape {} after one'
            ' iteration. Use tf.autograph.experimental.set_loop_options to set'
            ' shape invariants.'.format(name, entry_shape, exit_shape))
    else:
      init_shape = init.shape
      if not _is_subshape(init_shape, shape_invariant):
        raise ValueError(
            '"{}" has shape {} before the loop, which does not conform with'
            ' the shape invariant {}.'.format(name, init_shape,
                                              shape_invariant))
      if not _is_subshape(exit_shape, shape_invariant):
        raise ValueError(
            '"{}" has shape {} after the loop, which does not conform with'
            ' the shape invariant {}.'.format(
                name, exit_shape, shape_invariant))


def _verify_tf_loop_vars(init_vars,
                         iter_entry_vars,
                         iter_exit_vars,
                         symbol_names,
                         opts,
                         check_shapes=True):
  """Verifies loop variables for consistency."""
  if check_shapes and 'shape_invariants' in opts:
    shape_invariants = opts['shape_invariants']
  else:
    shape_invariants = nest.map_structure(lambda _: None, iter_entry_vars)

  named_vars = zip(symbol_names, init_vars, iter_entry_vars, iter_exit_vars,
                   shape_invariants)
  for name, init, entry, exit_, invariant in named_vars:
    try:
      nest.assert_same_structure(entry, exit_, expand_composites=True)
    except (ValueError, TypeError) as e:
      raise TypeError('"{}" does not have the same nested structure after one'
                      ' iteration.\n\n{}'.format(name, e))
    if invariant is not None:
      try:
        nest.assert_same_structure(init, invariant, expand_composites=False)
      except (ValueError, TypeError) as e:
        raise TypeError('"{}" does not have the same nested structure as its'
                        ' corresponding shape invariant.\n\n{}'.format(name, e))

    nest.map_structure(
        functools.partial(_verify_single_loop_var, name, check_shapes), init,
        entry, exit_, invariant)


def _verify_single_cond_var(name, body_var, orelse_var):
  """Verifies whether body_var and orelse_var are consistent."""
  if isinstance(body_var, (bool, int, float, str)):
    body_var = ops.convert_to_tensor_v2(body_var)

  if isinstance(orelse_var, (bool, int, float, str)):
    orelse_var = ops.convert_to_tensor_v2(orelse_var)

  if (not tensor_util.is_tf_type(body_var) or
      not tensor_util.is_tf_type(orelse_var)):
    return

  # TODO(mdan): Properly account for CompositeTensors.
  if (not hasattr(body_var, 'dtype') or
      not hasattr(orelse_var, 'dtype')):
    return

  if body_var.dtype != orelse_var.dtype:
    raise TypeError(
        '"{}" has dtype {} in the TRUE branch, but dtype {} in the FALSE'
        ' branch. TensorFlow control flow requires that they be the'
        ' same.'.format(name, body_var.dtype.name,
                        orelse_var.dtype.name))


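# For illustration: a conditional whose branches assign `x = 1` (TRUE) and
# `x = 1.0` (FALSE) fails the dtype check above, since the two literals
# convert to int32 and float32 tensors respectively.

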
def _verify_tf_cond_vars(body_vars, orelse_vars, symbol_names):
  """Verifies variables manipulated by a conditional for consistency."""
  basic_body_vars, composite_body_vars = body_vars
  basic_orelse_vars, composite_orelse_vars = orelse_vars
  assert isinstance(composite_body_vars, tuple)
  assert isinstance(composite_orelse_vars, tuple)

  # TODO(kkb): Make this more consistent.
  # The basic outputs should always be a tuple.
  if not isinstance(basic_body_vars, tuple):
    basic_body_vars = (basic_body_vars,)
  if not isinstance(basic_orelse_vars, tuple):
    basic_orelse_vars = (basic_orelse_vars,)

  body_vars = basic_body_vars + composite_body_vars
  orelse_vars = basic_orelse_vars + composite_orelse_vars

  named_vars = zip(symbol_names, body_vars, orelse_vars)
  for name, body_var, orelse_var in named_vars:
    try:
      nest.assert_same_structure(
          body_var, orelse_var, expand_composites=True)
    except (ValueError, TypeError) as e:
      raise TypeError(
          '"{}" does not have the same nested structure in the TRUE and FALSE'
          ' branches.\n\n{}'.format(name, str(e)))

    nest.map_structure(
        functools.partial(_verify_single_cond_var, name), body_var, orelse_var)


def for_stmt(iter_,
             extra_test,
             body,
             get_state,
             set_state,
             init_vars,
             basic_symbol_names,
             composite_symbol_names,
             opts):
  """Functional form of a for statement.

  The loop operates on a state, which includes all symbols that are
  variant across loop iterations, excluding the iterate as well as the
  variables local to the loop.

  For example, given the loop below that calculates the geometric and
  arithmetic means of some numbers:

    geo_mean = 1
    arith_mean = 0
    for i in range(n):
      a = numbers[i]
      geo_mean *= a
      arith_mean += a

  The state is represented by the variables geo_mean and arith_mean. The
  init_vars argument would contain the tuple (1, 0); the body would accept
  geo_mean and arith_mean as arguments and return a tuple with their new
  values.

  Args:
    iter_: The entity being iterated over.
    extra_test: Callable with the state as arguments, and boolean return type.
      An additional loop condition.
    body: Callable with the iterate and the state as arguments, and state as
      return type. The actual loop body.
    get_state: Additional callable which can capture additional state (such as
      the values of composite symbols). This is only useful when staging the
      loop.
    set_state: Additional callable which saves values captured by get_state
      back into the Python environment. This is only useful when staging the
      loop.
    init_vars: Tuple containing the initial state.
    basic_symbol_names: Tuple containing basic loop var names.
    composite_symbol_names: Tuple containing composite loop var names.
    opts: Optional dict of extra loop parameters.

  Returns:
    Tuple containing the final state.

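  Example:
    An illustrative sketch of the staged form of the loop above; get_state
    and set_state are no-ops here because no composite symbols are involved:

      def body(i, geo_mean, arith_mean):
        a = numbers[i]
        return geo_mean * a, arith_mean + a

      geo_mean, arith_mean = for_stmt(
          range(n), None, body,
          lambda: (), lambda _: None,
          (1, 0), ('geo_mean', 'arith_mean'), (), {})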
326 """ 327 if tensor_util.is_tf_type(iter_): 328 if tensors.is_range_tensor(iter_): 329 return _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state, 330 init_vars, basic_symbol_names, 331 composite_symbol_names, opts) 332 else: 333 return _known_len_tf_for_stmt(iter_, extra_test, body, get_state, 334 set_state, init_vars, basic_symbol_names, 335 composite_symbol_names, opts) 336 337 if isinstance(iter_, dataset_ops.DatasetV2): 338 return _tf_dataset_for_stmt(iter_, extra_test, body, get_state, set_state, 339 init_vars, basic_symbol_names, 340 composite_symbol_names, opts) 341 342 if isinstance(iter_, iterator_ops.OwnedIterator): 343 return _tf_iterator_for_stmt(iter_, extra_test, body, get_state, set_state, 344 init_vars, basic_symbol_names, 345 composite_symbol_names, opts) 346 347 if isinstance(iter_, ragged_tensor.RaggedTensor): 348 return _tf_ragged_for_stmt(iter_, extra_test, body, get_state, set_state, 349 init_vars, basic_symbol_names, 350 composite_symbol_names, opts) 351 352 if isinstance(iter_, input_lib.DistributedIterator): 353 raise NotImplementedError( 354 'distributed iterators not supported yet, use the distributed dataset' 355 ' directly') 356 357 if isinstance(iter_, input_lib.DistributedDataset): 358 return _tf_distributed_dataset_for_stmt(iter_, extra_test, body, init_vars) 359 360 return _py_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars) 361 362 363def _py_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars): 364 """Overload of for_stmt that executes a Python for loop.""" 365 del get_state, set_state 366 state = init_vars 367 368 if extra_test is not None: 369 if extra_test(*state): 370 for target in iter_: 371 state = body(target, *state) 372 if not extra_test(*state): 373 break 374 375 else: 376 for target in iter_: 377 state = body(target, *state) 378 379 return state 380 381 382def _known_len_tf_for_stmt(iter_, 383 extra_test, 384 body, 385 get_state, 386 set_state, 387 init_vars, 388 basic_symbol_names, 389 composite_symbol_names, 390 opts): 391 """Overload of for_stmt that iterates over TF entities that admit a length.""" 392 _disallow_undefs_into_loop(*init_vars) 393 394 n = py_builtins.len_(iter_) 395 # TODO(b/117628877): Revisit performance once XLA has the necessary support. 396 # Note: using a TensorArray creates an extra copy, but can calculate 397 # gradients more efficiently than StridedSlice. 398 ta = tensor_array_ops.TensorArray(iter_.dtype, size=n) 399 iter_ = ta.unstack(iter_) 400 401 def while_body(iterate_index, *loop_vars): 402 """Main loop body.""" 403 iterate = iter_.read(iterate_index) 404 new_vars = body(iterate, *loop_vars) 405 406 loop_vars = (iterate_index + 1,) 407 if new_vars: 408 loop_vars += new_vars 409 410 return loop_vars 411 412 def while_cond(iterate_index, *loop_vars): 413 if extra_test is not None: 414 return control_flow_ops.cond(iterate_index < n, 415 lambda: extra_test(*loop_vars), 416 lambda: False) 417 return iterate_index < n 418 419 opts['maximum_iterations'] = n 420 421 results = _tf_while_stmt( 422 while_cond, 423 while_body, 424 get_state, 425 set_state, 426 (array_ops.zeros_like(n),) + init_vars, 427 ('<internal iterate>',) + basic_symbol_names, 428 composite_symbol_names, 429 opts, 430 ) 431 432 # Note: the iteration index is not returned by the while loop, however 433 # if a symbol with the same name exists outside the loop, it will be captured 434 # by the loop variables and ultimately updated correctly. 
def _tf_ragged_for_stmt(iter_,
                        extra_test,
                        body,
                        get_state,
                        set_state,
                        init_vars,
                        basic_symbol_names,
                        composite_symbol_names,
                        opts):
  """Overload of for_stmt that iterates over TF ragged tensors."""
  _disallow_undefs_into_loop(*init_vars)

  # TODO(mdan): Move this into len()? Requires eager support.
  if iter_.shape and iter_.shape[0] is not None:
    n = iter_.shape[0]
  else:
    n = iter_.nrows()

  def while_body(iterate_index, *loop_vars):
    """Main loop body."""
    iterate = iter_[iterate_index]
    new_vars = body(iterate, *loop_vars)

    loop_vars = (iterate_index + 1,)
    if new_vars:
      loop_vars += new_vars

    return loop_vars

  def while_cond(iterate_index, *loop_vars):
    if extra_test is not None:
      return control_flow_ops.cond(
          iterate_index < n,
          lambda: extra_test(*loop_vars),
          lambda: False,
      )
    return iterate_index < n

  opts['maximum_iterations'] = n

  results = _tf_while_stmt(
      while_cond,
      while_body,
      get_state,
      set_state,
      (array_ops.zeros_like(n),) + init_vars,
      ('<internal iterate>',) + basic_symbol_names,
      composite_symbol_names,
      opts,
  )

  if isinstance(results, (tuple, list)):
    assert len(results) >= 1  # Has at least the iterate.
    if len(results) > 1:
      results = results[1:]
    else:
      results = ()

  return results


def _tf_range_for_stmt(iter_,
                       extra_test,
                       body,
                       get_state,
                       set_state,
                       init_vars,
                       basic_symbol_names,
                       composite_symbol_names,
                       opts):
  """Overload of for_stmt that iterates over a TF range (and elides it)."""
  _disallow_undefs_into_loop(*init_vars)

  start, limit, delta = iter_.op.inputs

  def while_body(iterate, *loop_vars):
    new_vars = body(iterate, *loop_vars)
    loop_vars = (iterate + delta,)

    if new_vars:
      loop_vars += new_vars

    return loop_vars

  def while_cond(iterate, *loop_vars):
    """Cond function for `tf.while_loop`."""
    main_test = math_ops.logical_or(
        math_ops.logical_and(delta >= 0, iterate < limit),
        math_ops.logical_and(delta < 0, iterate > limit))
    if extra_test is not None:
      return control_flow_ops.cond(
          main_test,
          lambda: extra_test(*loop_vars),
          lambda: False,
      )
    return main_test

  opts['maximum_iterations'] = math_ops.cast(
      misc.get_range_len(start, limit, delta), dtypes.int32)

  results = _tf_while_stmt(
      while_cond,
      while_body,
      get_state,
      set_state,
      (start,) + init_vars,
      ('<internal iterate>',) + basic_symbol_names,
      composite_symbol_names,
      opts,
  )

  # Note: the iteration index is not returned by the while loop, however
  # if a symbol with the same name exists outside the loop, it will be captured
  # by the loop variables and ultimately updated correctly.
  if isinstance(results, (tuple, list)):
    assert len(results) >= 1  # Has at least the iterate.
    if len(results) > 1:
      results = results[1:]
    else:
      results = ()

  return results


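# For illustration: for tf.range(10, 0, -2), delta < 0, so the staged
# condition reduces to `iterate > limit` and the loop visits 10, 8, 6, 4, 2
# before stopping.

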
def _tf_iterator_for_stmt(itr, extra_test, body, get_state, set_state,
                          init_vars, basic_symbol_names,
                          composite_symbol_names, opts):
  """Overload of for_stmt that iterates over TF Iterators. See for_stmt."""
  _disallow_undefs_into_loop(*init_vars)

  def while_body_actual(opt_iterate, *loop_vars):
    """Actual main loop body."""
    new_vars = body(opt_iterate.get_value(), *loop_vars)
    # TODO(mdan): Fix this inconsistency in the converter.
    if new_vars is None:
      new_vars = ()
    # Note: this verification duplicates that performed in tf_while_stmt,
    # but needs to be done earlier to prevent the tf.cond inside while_body
    # from blowing up first.
    _verify_tf_loop_vars(init_vars, loop_vars, new_vars,
                         basic_symbol_names + composite_symbol_names, opts)
    return new_vars

  def while_body(has_next, *loop_vars):
    """Main loop body."""
    opt_iterate = itr.get_next_as_optional()
    has_next = opt_iterate.has_value()

    if not init_vars:
      # cond_v2 requires at least one state tensor in V1.
      dummy_state = (constant_op.constant(()),)
    else:
      dummy_state = ()

    # TODO(mdan): If tf.while_loop supported Optional, this could be avoided.
    new_vars = control_flow_ops.cond(
        has_next,
        lambda: dummy_state + while_body_actual(opt_iterate, *loop_vars),
        lambda: dummy_state + loop_vars,
    )

    if dummy_state:
      new_vars = new_vars[1:]

    return (has_next,) + new_vars

  def while_cond(has_next, *loop_vars):
    if extra_test is not None:
      return control_flow_ops.cond(
          has_next,
          lambda: extra_test(*loop_vars),
          lambda: False,
      )
    return has_next

  final_vars = _tf_while_stmt(
      while_cond,
      while_body,
      get_state,
      set_state,
      (True,) + init_vars,
      ('<internal has_next>',) + basic_symbol_names,
      composite_symbol_names,
      opts,
  )
  return final_vars[1:]


def _tf_dataset_for_stmt(ds, extra_test, body, get_state, set_state, init_vars,
                         basic_symbol_names, composite_symbol_names, opts):
  """Overload of for_stmt that iterates over TF Datasets."""
  _disallow_undefs_into_loop(*init_vars)

  if extra_test is not None:
    assert init_vars, 'Lowering should always add state.'
    return _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state,
                                             set_state, init_vars,
                                             basic_symbol_names,
                                             composite_symbol_names, opts)

  return _dataset_for_stmt_no_extra_test(ds, body, get_state, set_state,
                                         init_vars, basic_symbol_names,
                                         composite_symbol_names, opts)


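# Note on scan semantics (illustrative): the scan used below maps
# (state, element) -> (new_state, output), matching the scan_body callables in
# this module. For example, scanning [1, 2, 3] with initial state 0 and a body
# of lambda s, x: (s + x, s + x) emits the running sums 1, 3, 6 and ends with
# state 6.

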
def _general_purpose_scan(ds, init_state, body):
  """Variant of Dataset.scan with semantics of general-purpose computation."""
  # Datasets are typically intended for data preprocessing. However, in
  # autograph loops they usually appear as general-purpose computations (for
  # example, a custom training loop). These two use cases require significantly
  # different optimization policies, the most important of which is the device
  # placement. The flag override for use_default_device below instructs the
  # runtime to treat the computation as general-purpose, rather than data
  # preprocessing.
  # TODO(mdan): s/use_default_device/specialize_for_input_pipeline.
  # TODO(mdan): Don't use private symbols.
  # pylint:disable=protected-access
  return dataset_ops._ScanDataset(
      ds, init_state, body, use_default_device=False)


def _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state,
                                      set_state, init_vars, basic_symbol_names,
                                      composite_symbol_names, opts):
  """Overload of _dataset_for_stmt with early stopping. See for_stmt."""

  # TODO(mdan): Simplify this - following it is extremely difficult.

  init_state = get_state()
  aug_init_vars = init_vars, init_state

  def scan_body(aug_vars, iterate):
    """The main loop body wrapper. Only calculates the stop condition."""
    loop_vars, state = aug_vars

    def true_fn():
      """Main path - stop condition is not set."""
      set_state(state)
      new_vars = body(iterate, *loop_vars)
      new_state = get_state()
      _verify_tf_loop_vars(
          init_vars + init_state,
          loop_vars + state,
          new_vars + new_state,
          basic_symbol_names + composite_symbol_names,
          opts,
          check_shapes=False)
      return new_vars, new_state

    extra_cond = extra_test(*loop_vars)
    new_vars, new_state = control_flow_ops.cond(
        extra_cond,
        true_fn,
        lambda: (loop_vars, state),
    )

    scan_outputs = new_vars, new_state, extra_cond
    # Note: new_aug_vars is the actual state of scan; scan_outputs is its
    # output (hence the redundancy).
    # get_state will pull any mutations that body may have made.
    new_aug_vars = new_vars, new_state
    return new_aug_vars, scan_outputs

  def take_while_predicate(unused_loop_vars, unused_state, extra_cond):
    return extra_cond

  def reduce_body(unused_aug_vars, scan_outputs):
    output_aug_vars, output_state, extra_cond = scan_outputs
    del extra_cond
    return output_aug_vars, output_state

  ds = _general_purpose_scan(ds, aug_init_vars, scan_body)
  ds = ds.apply(take_while_ops.take_while(take_while_predicate))
  final_aug_vars = ds.reduce(aug_init_vars, reduce_body)
  final_vars, final_state = final_aug_vars
  set_state(final_state)
  return final_vars


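# Illustrative trace of the pipeline above: with ds = [1, 2, 3, 4],
# init_vars = (0,), body(x, total) = (total + x,) and
# extra_test(total) = (total < 3), scan emits totals 1, 3, 3, 3 with
# extra_cond True, True, False, False; take_while keeps the first two
# elements, and reduce returns the last one, so the final state is total == 3.

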
def _dataset_for_stmt_no_extra_test(ds, body, get_state, set_state, init_vars,
                                    basic_symbol_names, composite_symbol_names,
                                    opts):
  """Overload of _dataset_for_stmt without early stopping. See for_stmt."""
  init_state = get_state()
  assert isinstance(init_vars, tuple)
  assert isinstance(init_state, tuple)

  symbol_names = basic_symbol_names + composite_symbol_names

  # Workaround for Dataset.reduce not allowing empty state tensors - create
  # a dummy state variable that remains unused.
  # TODO(mdan): reduce should allow and match empty structures.
  no_vars = not init_vars
  no_state = not init_state

  if no_vars:
    init_vars = (constant_op.constant(0),)
    symbol_names = ('<internal dummy>',) + symbol_names
  if no_state:
    init_state = (constant_op.constant(0),)
    symbol_names = symbol_names + ('<internal dummy>',)

  def scan_body(aug_vars, iterate):
    """The main loop body wrapper."""
    loop_vars, state = aug_vars
    if not no_state:
      set_state(state)

    if no_vars:
      body(iterate)
      new_vars = loop_vars
    else:
      new_vars = body(iterate, *loop_vars)

    if no_state:
      new_state = state
    else:
      new_state = get_state()

    _verify_tf_loop_vars(
        init_vars + init_state,
        loop_vars + state,
        new_vars + new_state,
        symbol_names,
        opts,
        check_shapes=False)

    scan_outputs = new_vars, new_state
    # Note: new_aug_vars is the actual state of scan; scan_outputs is its
    # output (hence the redundancy).
    # get_state will pull any mutations that body may have made.
    new_aug_vars = new_vars, new_state
    return new_aug_vars, scan_outputs

  def reduce_body(unused_aug_vars, scan_outputs):
    output_aug_vars, output_state = scan_outputs
    return output_aug_vars, output_state

  aug_vars = init_vars, get_state()
  ds = _general_purpose_scan(ds, aug_vars, scan_body)
  final_vars, final_state = ds.reduce(aug_vars, reduce_body)
  set_state(final_state)

  if no_vars:
    return ()
  return final_vars


def _tf_distributed_dataset_for_stmt(iter_, extra_test, body, init_state):
  """Overload of for..in statement that iterates over the input."""
  _disallow_undefs_into_loop(*init_state)

  if extra_test is not None:
    raise NotImplementedError(
        'break and return statements are not yet supported in '
        'for ... in distributed input loops.')

  def reduce_body(state, iterate):
    new_state = body(iterate, *state)
    return new_state

  if init_state:
    return iter_.reduce(init_state, reduce_body)

  def reduce_body_with_dummy_state(state, iterate):
    reduce_body((), iterate)
    return state
  iter_.reduce((constant_op.constant(0),), reduce_body_with_dummy_state)
  return ()


def while_stmt(test,
               body,
               get_state,
               set_state,
               init_vars,
               basic_symbol_names,
               composite_symbol_names,
               opts):
  """Functional form of a while statement.

  The loop operates on a so-called state, which includes all symbols that are
  variant across loop iterations. In what follows we refer to state as either
  a tuple of entities that represent an actual state, or a list of arguments
  of the corresponding types.

  Args:
    test: Callable with the state as arguments, and boolean return type. The
      loop condition.
    body: Callable with the state as arguments, and state as return type. The
      actual loop body.
    get_state: Additional callable which can capture additional state (such as
      the values of composite symbols). This is only useful when staging the
      loop.
    set_state: Additional callable which saves values captured by get_state
      back into the Python environment. This is only useful when staging the
      loop.
    init_vars: Tuple containing the initial state.
    basic_symbol_names: Tuple containing basic loop var names.
    composite_symbol_names: Tuple containing composite loop var names.
    opts: Optional dict of extra loop parameters.

  Returns:
    Tuple containing the final state.

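  Example:
    An illustrative sketch of a staged counter loop, equivalent to `x = 0`
    followed by `while x < 5: x += 1`; get_state and set_state are no-ops
    here because no composite symbols are involved:

      x, = while_stmt(
          lambda x: x < 5,
          lambda x: (x + 1,),
          lambda: (), lambda _: None,
          (0,), ('x',), (), {})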
850 """ 851 852 # Evaluate the initial test once in order to do the dispatch. The evaluation 853 # is isolated to minimize unwanted side effects. 854 # TODO(mdan): Do a full iteration - some state types might lower to Tensor. 855 with func_graph.FuncGraph('tmp').as_default(): 856 init_test = test(*init_vars) 857 858 # TensorFlow: Multiple evaluations are acceptable in this case, so we're fine 859 # with the re-evaluation of `test` that `_tf_while_stmt` will make. 860 if tensors.is_dense_tensor(init_test): 861 return _tf_while_stmt(test, body, get_state, set_state, init_vars, 862 basic_symbol_names, composite_symbol_names, opts) 863 864 # Normal Python: We already consumed one evaluation of `test`; consistently, 865 # unroll one iteration before dispatching to a normal loop. 866 # TODO(mdan): Push the "init_test" value via opts into _py_while_stmt? 867 if not init_test: 868 return init_vars 869 init_vars = body(*init_vars) 870 871 return _py_while_stmt(test, body, get_state, set_state, init_vars, opts) 872 873 874def _shape_invariants_mapping_to_positional_list(mapping, keys): 875 # The keys are not expected to be hashable. 876 mapping = {id(k): (k, v) for k, v in mapping} 877 result = [] 878 for k in keys: 879 map_key, map_val = mapping.get(id(k), (None, None)) 880 result.append(map_val if map_key is k else None) 881 return tuple(result) 882 883 884def _tf_while_stmt(test, body, get_state, set_state, init_vars, 885 basic_symbol_names, composite_symbol_names, opts): 886 """Overload of while_stmt that stages a TF while_stmt.""" 887 _disallow_undefs_into_loop(*init_vars) 888 889 aug_init_vars = init_vars + get_state() 890 891 # TODO(mdan): Simplify this. 892 loop_vars_slice = slice(len(init_vars)) 893 state_slice = slice(len(init_vars), None) 894 895 def aug_test(*aug_loop_vars): 896 state = aug_loop_vars[state_slice] 897 set_state(state) 898 return test(*aug_loop_vars[loop_vars_slice]) 899 900 def aug_body(*aug_loop_vars): 901 """Main loop body.""" 902 state = aug_loop_vars[state_slice] 903 set_state(state) 904 loop_vars = body(*aug_loop_vars[loop_vars_slice]) 905 new_state = loop_vars + get_state() 906 _verify_tf_loop_vars(aug_init_vars, aug_loop_vars, new_state, 907 basic_symbol_names + composite_symbol_names, opts) 908 909 return new_state 910 911 # Non-v2 while_loop unpacks the results when there is only one return value. 912 # This enforces consistency across versions. 913 opts['return_same_structure'] = True 914 915 if 'shape_invariants' in opts: 916 opts['shape_invariants'] = _shape_invariants_mapping_to_positional_list( 917 opts['shape_invariants'], aug_init_vars) 918 919 final_aug_vars = control_flow_ops.while_loop(aug_test, aug_body, 920 aug_init_vars, **opts) 921 final_state = final_aug_vars[state_slice] 922 set_state(final_state) 923 return final_aug_vars[loop_vars_slice] 924 925 926class _PythonLoopChecker(object): 927 """Verifies Python loops for TF-specific limits.""" 928 929 def __init__(self): 930 self.iterations = 0 931 self.check_inefficient_unroll = WARN_INEFFICIENT_UNROLL 932 933 # Triggered when we decided to test the op counts. 
class _PythonLoopChecker(object):
  """Verifies Python loops for TF-specific limits."""

  def __init__(self):
    self.iterations = 0
    self.check_inefficient_unroll = WARN_INEFFICIENT_UNROLL
    self.ops_before_iteration = None

    # Triggered when we decided to test the op counts.
    self.check_op_count_after_iteration = False

  def _get_ops(self):
    return ops.get_default_graph().get_operations()

  def _check_unroll_limits(self):
    if LIMIT_PYTHON_ITERATIONS and self.iterations > PYTHON_MAX_ITERATIONS:
      raise ValueError('iteration limit exceeded')

  def _stop_checking_inefficient_unroll(self):
    self.check_inefficient_unroll = False
    self.ops_before_iteration = None

  def _verify_inefficient_unroll(self):
    """Checks for possibly-inefficient creation of ops in a Python loop."""
    assert self.ops_before_iteration is not None
    ops_after_iteration = self._get_ops()
    new_ops = tuple(
        op for op in ops_after_iteration if op not in self.ops_before_iteration)

    if len(new_ops) < INEFFICIENT_UNROLL_MIN_OPS:
      return False

    # TODO(mdan): Add location information.
    ag_logging.warn(
        'TensorFlow ops are being created in a Python loop with a large number'
        ' of iterations. This can lead to slow startup. Did you mean to use a'
        ' TensorFlow loop? For example, `while True:` is a Python loop, and'
        ' `while tf.constant(True):` is a TensorFlow loop. The following'
        ' ops were created after iteration %s: %s', self.iterations, new_ops)
    return True

  def before_iteration(self):
    """Called before each iteration in a Python loop."""
    if (self.check_inefficient_unroll and
        self.iterations > INEFFICIENT_UNROLL_MIN_ITERATIONS):
      self.ops_before_iteration = self._get_ops()
      self.check_op_count_after_iteration = True

  def after_iteration(self):
    """Called after each iteration in a Python loop."""
    self.iterations += 1

    self._check_unroll_limits()

    if self.check_inefficient_unroll and self.check_op_count_after_iteration:
      did_warn = self._verify_inefficient_unroll()
      if did_warn:
        self._stop_checking_inefficient_unroll()  # Only warn once.
      elif self.iterations > INEFFICIENT_UNROLL_MIN_ITERATIONS + 3:
        # Once deciding to check the op counts, only do it for a few iterations.
        self._stop_checking_inefficient_unroll()


def _py_while_stmt(test, body, get_state, set_state, init_vars, opts):
  """Overload of while_stmt that executes a Python while loop."""
  del opts, get_state, set_state

  if __debug__:
    checker = _PythonLoopChecker()

  loop_vars = init_vars
  while test(*loop_vars):

    if __debug__:
      checker.before_iteration()

    loop_vars = body(*loop_vars)

    if __debug__:
      checker.after_iteration()

  return loop_vars


def if_stmt(cond,
            body,
            orelse,
            get_state,
            set_state,
            basic_symbol_names,
            composite_symbol_names):
  """Functional form of an if statement.

  Args:
    cond: Boolean.
    body: Callable with no arguments, and outputs of the positive (if) branch
      as return type.
    orelse: Callable with no arguments, and outputs of the negative (else)
      branch as return type.
    get_state: Function that returns a tuple containing the values of all
      composite symbols modified within the conditional. This allows access to
      state that branches may mutate through side effects. This function is
      not needed and should not be called when dispatching to code matching
      Python's default semantics. This is useful for checkpointing to avoid
      unintended side effects when staging requires evaluating all code paths.
    set_state: Function to set the values of all composite symbols modified
      within the conditional. This is the complement to get_state, used to
      restore checkpointed values. The single argument is a tuple containing
      values for each composite symbol that may be modified in a branch of the
      conditional. This is usually the result of a call to get_state.
    basic_symbol_names: Tuple containing basic loop var names.
    composite_symbol_names: Tuple containing composite loop var names.

  Returns:
    Tuple containing the statement outputs.

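  Example:
    An illustrative sketch, equivalent to `x = 1 if cond else -1` for a given
    boolean `cond`; get_state and set_state are no-ops here because no
    composite symbols are involved:

      x = if_stmt(cond, lambda: 1, lambda: -1,
                  lambda: (), lambda _: None, ('x',), ())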
  """
  # Note: tf.cond doesn't support SparseTensor.
  if tensors.is_dense_tensor(cond):
    return tf_if_stmt(cond, body, orelse, get_state, set_state,
                      basic_symbol_names, composite_symbol_names)
  else:
    return _py_if_stmt(cond, body, orelse)


def tf_if_stmt(cond, body, orelse, get_state, set_state, basic_symbol_names,
               composite_symbol_names):
  """Overload of if_stmt that stages a TF cond."""
  body = _wrap_disallow_undefs_from_cond(body, branch_name='if')
  orelse = _wrap_disallow_undefs_from_cond(orelse, branch_name='else')
  body = _isolate_state(body, get_state, set_state)
  orelse = _isolate_state(orelse, get_state, set_state)

  # `state` currently includes the values of any composite symbols (e.g.
  # `a.b`) modified by the conditional. `final_vars` includes the values of
  # basic symbols (e.g. `a`) which cannot be passed by reference and must be
  # returned. See _isolate_state.
  # TODO(mdan): We should minimize calls to get/set_state.

  body_branch = 0
  orelse_branch = 1
  result = [None, None]

  def error_checking_body():
    result[body_branch] = body()
    if result[orelse_branch] is not None:
      _verify_tf_cond_vars(result[body_branch], result[orelse_branch],
                           basic_symbol_names + composite_symbol_names)
    return result[body_branch]

  def error_checking_orelse():
    result[orelse_branch] = orelse()
    if result[body_branch] is not None:
      _verify_tf_cond_vars(result[body_branch], result[orelse_branch],
                           basic_symbol_names + composite_symbol_names)
    return result[orelse_branch]

  final_vars, final_state = control_flow_ops.cond(cond, error_checking_body,
                                                  error_checking_orelse)

  set_state(final_state)

  return final_vars


def _isolate_state(func, get_state, set_state):
  """Wraps func to (best-effort) isolate state mutations that func may do.

  The simplest example of state mutation is mutation of variables (e.g. via
  attributes), or modification of globals.

  This allows us to more safely execute the function without worrying about
  side effects when the function wasn't normally expected to execute. For
  example, staging requires that the function is executed ahead of time, and
  we need to ensure its effects are not observed during normal execution.

  Args:
    func: () -> Any
    get_state: () -> Any, returns the current state
    set_state: (Any) -> None, resets the state to the specified values.
      Typically the result of an earlier call to `get_state`.

  Returns:
    A wrapper of `func` which returns Tuple[Any, Any], where the first element
    is the return value of `func`, and the second is the final state values.
  """

  def wrapper():
    init_state = get_state()
    new_vars = func()
    # TODO(mdan): These should be copies, lest set_state might affect them.
    new_state = get_state()
    set_state(init_state)
    return new_vars, new_state

  return wrapper


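# Illustrative usage of _isolate_state; `obj` and the state functions below
# are hypothetical:
#
#   obj.n = 0
#   def get_state(): return (obj.n,)
#   def set_state(values): obj.n, = values
#   def func():
#     obj.n += 1
#     return obj.n
#
#   result, new_state = _isolate_state(func, get_state, set_state)()
#   # result == 1, new_state == (1,), and obj.n is restored to 0.

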
def _wrap_disallow_undefs_from_cond(func, branch_name):
  """Wraps conditional branch to disallow returning undefined symbols."""

  def wrapper():
    """Calls function and raises an error if undefined symbols are returned."""
    results = func()

    if isinstance(results, tuple):
      results_tuple = results
    else:
      results_tuple = results,
    undefined = [v for v in results_tuple if isinstance(v, variables.Undefined)]
    if undefined:
      raise ValueError(
          'The following symbols must also be initialized in the {} branch: {}.'
          ' Alternatively, you may initialize them before the if'
          ' statement.'.format(branch_name,
                               tuple(s.symbol_name for s in undefined)))

    for result in results_tuple:
      if isinstance(result, variables.UndefinedReturnValue):
        raise ValueError(
            'A value must also be returned from the {} branch. If a value is '
            'returned from one branch of a conditional, a value must be '
            'returned from all branches.'.format(branch_name))

    return results

  return wrapper


def _py_if_stmt(cond, body, orelse):
  """Overload of if_stmt that executes a Python if statement."""
  return body() if cond else orelse()