1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Handles control flow statements: while, for, if. 16 17Python 2 compatibility version. Not maintained. 18""" 19 20from __future__ import absolute_import 21from __future__ import division 22from __future__ import print_function 23 24import gast 25 26from tensorflow.python.autograph.core import converter 27from tensorflow.python.autograph.lang import directives 28from tensorflow.python.autograph.pyct import anno 29from tensorflow.python.autograph.pyct import ast_util 30from tensorflow.python.autograph.pyct import cfg 31from tensorflow.python.autograph.pyct import parser 32from tensorflow.python.autograph.pyct import qual_names 33from tensorflow.python.autograph.pyct import templates 34from tensorflow.python.autograph.pyct.static_analysis import activity 35from tensorflow.python.autograph.pyct.static_analysis import annos 36from tensorflow.python.autograph.pyct.static_analysis import liveness 37from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions 38from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs 39 40 41# TODO(mdan): Refactor functions to make them smaller. 
class ControlFlowTransformer(converter.Base):
  """Transforms control flow structures like loops and conditionals.

  `if`, `while` and `for` statements are rewritten into calls to
  `ag__.if_stmt`, `ag__.while_stmt` and `ag__.for_stmt` respectively. Branch
  and loop bodies become nested function definitions, and the symbols they
  modify are passed in and returned explicitly.
  """

  def _create_cond_branch(self, body_name, aliased_orig_names,
                          aliased_new_names, body, returns):
    """Builds the function definition for one branch of a conditional.

    Args:
      body_name: name of the generated function.
      aliased_orig_names: symbols to alias inside the function. Each one is
        copied into a fresh local name at the top of the function.
      aliased_new_names: the alias names, parallel to `aliased_orig_names`.
      body: the (already renamed) statements of the branch.
      returns: values returned by the function; a single value is returned
        as-is, several values are returned as a tuple.

    Returns:
      The AST nodes produced by `templates.replace`.
    """
    if len(returns) == 1:
      template = """
        return retval
      """
      return_stmt = templates.replace(template, retval=returns[0])
    else:
      template = """
        return (retvals,)
      """
      return_stmt = templates.replace(template, retvals=returns)

    if aliased_orig_names:
      alias_declarations = []
      for new_name, old_name in zip(aliased_new_names, aliased_orig_names):
        # The original symbol may be unbound when the branch runs; in that
        # case the alias receives an `ag__.Undefined` placeholder instead.
        template = """
          try:
            aliased_new_name = aliased_orig_name
          except NameError:
            aliased_new_name = ag__.Undefined(symbol_name)
        """

        alias_declarations.extend(
            templates.replace(
                template,
                aliased_new_name=new_name,
                aliased_orig_name=old_name,
                symbol_name=gast.Constant(str(old_name), kind=None)))

      template = """
        def body_name():
          alias_declarations
          body
          return_stmt
      """
      return templates.replace(
          template,
          alias_declarations=alias_declarations,
          body_name=body_name,
          body=body,
          return_stmt=return_stmt)
    else:
      template = """
        def body_name():
          body
          return_stmt
      """
      return templates.replace(
          template, body_name=body_name, body=body, return_stmt=return_stmt)

  def _create_cond_expr(self, results, test, body_name, orelse_name,
                        state_getter_name, state_setter_name,
                        basic_symbol_names, composite_symbol_names):
    """Builds the `ag__.if_stmt` call that dispatches to the branch functions.

    When `results` is not None, the value of the call is assigned to it;
    otherwise the call is emitted as a bare expression statement.
    """
    if results is not None:
      template = """
        results = ag__.if_stmt(test, body_name, orelse_name,
                               state_getter_name, state_setter_name,
                               (basic_symbol_names,),
                               (composite_symbol_names,))
      """
      return templates.replace(
          template,
          test=test,
          results=results,
          body_name=body_name,
          orelse_name=orelse_name,
          state_getter_name=state_getter_name,
          state_setter_name=state_setter_name,
          basic_symbol_names=basic_symbol_names,
          composite_symbol_names=composite_symbol_names)
    else:
      template = """
        ag__.if_stmt(test, body_name, orelse_name, getter_name, setter_name,
                     (basic_symbol_names,), (composite_symbol_names,))
      """
      return templates.replace(
          template,
          test=test,
          body_name=body_name,
          orelse_name=orelse_name,
          getter_name=state_getter_name,
          setter_name=state_setter_name,
          basic_symbol_names=basic_symbol_names,
          composite_symbol_names=composite_symbol_names)

  def _fmt_symbols(self, symbol_set):
    """Returns a comma-separated listing of `symbol_set`, or 'no variables'."""
    if not symbol_set:
      return 'no variables'
    return ', '.join(map(str, symbol_set))

  def _determine_aliased_symbols(self, scope, node_defined_in):
    """Returns non-composite symbols modified in `scope` and in `node_defined_in`."""
    modified_live = scope.modified & node_defined_in
    # Composite symbols are handled elsewhere; see _create_state_functions.
    return {s for s in modified_live if not s.is_composite()}

  def _create_state_functions(self, composites, state_getter_name,
                              state_setter_name):
    """Builds getter/setter function definitions for composite symbols.

    The getter returns the current values of `composites` as a tuple and the
    setter assigns them back from a tuple. With no composites, the getter
    returns `()` and the setter does nothing.
    """
    if composites:
      composite_tuple = tuple(composites)

      template = """
        def state_getter_name():
          return composite_tuple,
        def state_setter_name(vals):
          composite_tuple, = vals
      """
      node = templates.replace(
          template,
          state_getter_name=state_getter_name,
          state_setter_name=state_setter_name,
          composite_tuple=composite_tuple)
    else:
      template = """
        def state_getter_name():
          return ()
        def state_setter_name(_):
          pass
      """
      node = templates.replace(
          template,
          state_getter_name=state_getter_name,
          state_setter_name=state_setter_name)

    return node

  def _create_loop_options(self, node):
    """Converts `set_loop_options` directive arguments into a `gast.Dict`.

    Returns an empty dict literal when `node` has no directives, has no
    `set_loop_options` directive, or that directive carries no options.
    """
    if not anno.hasanno(node, anno.Basic.DIRECTIVES):
      return gast.Dict([], [])

    loop_directives = anno.getanno(node, anno.Basic.DIRECTIVES)
    if directives.set_loop_options not in loop_directives:
      return gast.Dict([], [])

    opts_dict = loop_directives[directives.set_loop_options]
    # Build keys and values separately rather than via `zip(*items())`: with an
    # empty options dict, that unpacking raises ValueError.
    keys = [gast.Constant(k, kind=None) for k in opts_dict]
    # ast and gast don't play well with tuples.
    values = list(opts_dict.values())
    return gast.Dict(keys, values)

  def _create_undefined_assigns(self, undefined_symbols):
    """Emits `var = ag__.Undefined('<name>')` for every given symbol."""
    assignments = []
    for s in undefined_symbols:
      template = '''
        var = ag__.Undefined(symbol_name)
      '''
      assignments += templates.replace(
          template,
          var=s,
          symbol_name=gast.Constant(s.ssf(), kind=None))
    return assignments

  def visit_If(self, node):
    """Rewrites an `if` statement into branch functions plus `ag__.if_stmt`.

    Non-composite symbols that are modified in a branch and live after the
    statement are returned from the branch functions. Composite symbols
    modified in either branch are exposed through generated state
    getter/setter functions.
    """
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

    # Note: this information needs to be extracted before the body conversion
    # that happens in the call to generic_visit below, because the conversion
    # generates nodes that lack static analysis annotations.
    need_alias_in_body = self._determine_aliased_symbols(
        body_scope, defined_in)
    need_alias_in_orelse = self._determine_aliased_symbols(
        orelse_scope, defined_in)

    node = self.generic_visit(node)

    modified_in_cond = body_scope.modified | orelse_scope.modified
    returned_from_cond = set()
    composites = set()
    for s in modified_in_cond:
      if s in live_out and not s.is_composite():
        returned_from_cond.add(s)
      if s.is_composite():
        # Special treatment for compound objects, always return them.
        # This allows special handling within the if_stmt itself.
        # For example, in TensorFlow we need to restore the state of composite
        # symbols to ensure that only effects from the executed branch are seen.
        composites.add(s)

    created_in_body = body_scope.modified & returned_from_cond - defined_in
    created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in

    basic_created_in_body = tuple(
        s for s in created_in_body if not s.is_composite())
    basic_created_in_orelse = tuple(
        s for s in created_in_orelse if not s.is_composite())

    # These variables are defined only in a single branch. This is fine in
    # Python so we pass them through. Another backend, e.g. Tensorflow, may need
    # to handle these cases specially or throw an Error.
    possibly_undefined = (set(basic_created_in_body) ^
                          set(basic_created_in_orelse))

    # Alias the closure variables inside the conditional functions, to allow
    # the functions access to the respective variables.
    # We will alias variables independently for body and orelse scope,
    # because different branches might write different variables.
    aliased_body_orig_names = tuple(need_alias_in_body)
    aliased_orelse_orig_names = tuple(need_alias_in_orelse)
    aliased_body_new_names = tuple(
        self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced)
        for s in aliased_body_orig_names)
    aliased_orelse_new_names = tuple(
        self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced)
        for s in aliased_orelse_orig_names)

    alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names))
    alias_orelse_map = dict(
        zip(aliased_orelse_orig_names, aliased_orelse_new_names))

    node_body = ast_util.rename_symbols(node.body, alias_body_map)
    node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

    # The condition variable and both branch function names are visible to
    # both generated branch functions (e.g. `cond_var_name` is read by the
    # `match_staging_level` returns below). Reserve the names referenced in
    # either branch so that none of them shadows a user symbol.
    all_referenced = body_scope.referenced | orelse_scope.referenced
    cond_var_name = self.ctx.namer.new_symbol('cond', all_referenced)
    body_name = self.ctx.namer.new_symbol('if_true', all_referenced)
    orelse_name = self.ctx.namer.new_symbol('if_false', all_referenced)
    state_getter_name = self.ctx.namer.new_symbol('get_state', all_referenced)
    state_setter_name = self.ctx.namer.new_symbol('set_state', all_referenced)

    returned_from_cond = tuple(returned_from_cond)
    composites = tuple(composites)

    if returned_from_cond:
      if len(returned_from_cond) == 1:
        cond_results = returned_from_cond[0]
      else:
        cond_results = gast.Tuple([s.ast() for s in returned_from_cond], None)

      returned_from_body = tuple(
          alias_body_map[s] if s in need_alias_in_body else s
          for s in returned_from_cond)
      returned_from_orelse = tuple(
          alias_orelse_map[s] if s in need_alias_in_orelse else s
          for s in returned_from_cond)

    else:
      # When the cond would return no value, we leave the cond called without
      # results. That in turn should trigger the side effect guards. The
      # branch functions will return a dummy value that ensures cond
      # actually has some return value as well.
      cond_results = None
      # TODO(mdan): Replace with None once side_effect_guards is retired.
      # Two separate expressions are built on purpose: each branch needs its
      # own AST node.
      returned_from_body = (templates.replace_as_expression(
          'ag__.match_staging_level(1, cond_var_name)',
          cond_var_name=cond_var_name),)
      returned_from_orelse = (templates.replace_as_expression(
          'ag__.match_staging_level(1, cond_var_name)',
          cond_var_name=cond_var_name),)

    cond_assign = self.create_assignment(cond_var_name, node.test)
    body_def = self._create_cond_branch(
        body_name,
        aliased_orig_names=aliased_body_orig_names,
        aliased_new_names=aliased_body_new_names,
        body=node_body,
        returns=returned_from_body)
    orelse_def = self._create_cond_branch(
        orelse_name,
        aliased_orig_names=aliased_orelse_orig_names,
        aliased_new_names=aliased_orelse_new_names,
        body=node_orelse,
        returns=returned_from_orelse)
    undefined_assigns = self._create_undefined_assigns(possibly_undefined)
    composite_defs = self._create_state_functions(
        composites, state_getter_name, state_setter_name)

    basic_symbol_names = tuple(
        gast.Constant(str(symbol), kind=None) for symbol in returned_from_cond)
    composite_symbol_names = tuple(
        gast.Constant(str(symbol), kind=None) for symbol in composites)

    cond_expr = self._create_cond_expr(cond_results, cond_var_name, body_name,
                                       orelse_name, state_getter_name,
                                       state_setter_name, basic_symbol_names,
                                       composite_symbol_names)

    if_ast = (
        undefined_assigns + composite_defs + body_def + orelse_def +
        cond_assign + cond_expr)
    return if_ast

  def _get_basic_loop_vars(self, modified_symbols, live_in, live_out):
    """Returns modified non-composite symbols live into or out of the loop."""
    # The loop variables corresponding to simple symbols (e.g. `x`).
    basic_loop_vars = []
    for s in modified_symbols:
      if s.is_composite():
        # TODO(mdan): Raise an error when this happens for a TF loop.
        continue
      # Variables not live into or out of the loop are considered local to the
      # loop.
      if s not in live_in and s not in live_out:
        continue
      basic_loop_vars.append(s)
    return frozenset(basic_loop_vars)

  def _get_composite_loop_vars(self, modified_symbols, live_in):
    """Returns modified composite symbols whose symbol parents are live in."""
    # The loop variables corresponding to composite symbols (e.g. `self.x`).
    composite_loop_vars = []
    for s in modified_symbols:
      if not s.is_composite():
        continue
      # Mutations made to objects created inside the loop will appear as writes
      # to composite symbols. Because these mutations appear as modifications
      # made to composite symbols, we check whether the composite's parent is
      # actually live into the loop.
      # Example:
      #   while cond:
      #     x = Foo()
      #     x.foo = 2 * x.foo  # x.foo is live into the loop, but x is not.
      #
      # Note that some parents might not be symbols - for example, in x['foo'],
      # 'foo' is a parent, but it's a literal, not a symbol. We don't check the
      # liveness of literals.
      support_set_symbols = tuple(
          sss for sss in s.support_set if sss.is_symbol())
      if not all(sss in live_in for sss in support_set_symbols):
        continue
      composite_loop_vars.append(s)
    return frozenset(composite_loop_vars)

  def _get_loop_vars(self, node, modified_symbols):
    """Classifies the symbols a loop modifies.

    Args:
      node: the loop node; must carry body-scope, defined-vars-in and
        liveness annotations.
      modified_symbols: the symbols modified by the loop.

    Returns:
      A tuple `(basic_loop_vars, composite_loop_vars, reserved_symbols,
      undefined_lives)`. `reserved_symbols` is the body scope's referenced
      set; `undefined_lives` are basic loop vars not defined on loop entry.
    """
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
    live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN)
    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
    reserved_symbols = body_scope.referenced

    basic_loop_vars = self._get_basic_loop_vars(
        modified_symbols, live_in, live_out)
    composite_loop_vars = self._get_composite_loop_vars(
        modified_symbols, live_in)

    # Variable that are used or defined inside the loop, but not defined
    # before entering the loop. Only simple variables must be defined. The
    # composite ones will be implicitly checked at runtime.
    undefined_lives = basic_loop_vars - defined_in

    return (basic_loop_vars, composite_loop_vars, reserved_symbols,
            undefined_lives)

  def _loop_var_constructs(self, basic_loop_vars):
    """Returns `(loop_vars, loop_vars_ast_tuple)` for the given symbols.

    `loop_vars` is the single symbol when there is exactly one, otherwise a
    tuple of symbols. `loop_vars_ast_tuple` is always a `gast.Tuple`.
    """
    loop_vars = tuple(basic_loop_vars)
    loop_vars_ast_tuple = gast.Tuple([n.ast() for n in loop_vars], None)

    if len(loop_vars) == 1:
      loop_vars = loop_vars[0]

    return loop_vars, loop_vars_ast_tuple

  def visit_While(self, node):
    """Rewrites a `while` loop into body/test functions plus `ag__.while_stmt`."""
    node = self.generic_visit(node)

    (basic_loop_vars, composite_loop_vars, reserved_symbols,
     possibly_undefs) = self._get_loop_vars(
         node,
         anno.getanno(node, annos.NodeAnno.BODY_SCOPE).modified)
    loop_vars, loop_vars_ast_tuple = self._loop_var_constructs(
        basic_loop_vars)

    state_getter_name = self.ctx.namer.new_symbol('get_state', reserved_symbols)
    state_setter_name = self.ctx.namer.new_symbol('set_state', reserved_symbols)
    state_functions = self._create_state_functions(
        composite_loop_vars, state_getter_name, state_setter_name)

    basic_symbol_names = tuple(
        gast.Constant(str(symbol), kind=None) for symbol in basic_loop_vars)
    composite_symbol_names = tuple(
        gast.Constant(str(symbol), kind=None) for symbol in composite_loop_vars)

    opts = self._create_loop_options(node)

    # TODO(mdan): Use a single template.
    # If the body and test functions took a single tuple for loop_vars, instead
    # of *loop_vars, then a single template could be used.
    if loop_vars:
      template = """
        state_functions
        def body_name(loop_vars):
          body
          return loop_vars,
        def test_name(loop_vars):
          return test
        loop_vars_ast_tuple = ag__.while_stmt(
            test_name,
            body_name,
            state_getter_name,
            state_setter_name,
            (loop_vars,),
            (basic_symbol_names,),
            (composite_symbol_names,),
            opts)
      """
      node = templates.replace(
          template,
          loop_vars=loop_vars,
          loop_vars_ast_tuple=loop_vars_ast_tuple,
          test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols),
          test=node.test,
          body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
          body=node.body,
          state_functions=state_functions,
          state_getter_name=state_getter_name,
          state_setter_name=state_setter_name,
          basic_symbol_names=basic_symbol_names,
          composite_symbol_names=composite_symbol_names,
          opts=opts)
    else:
      template = """
        state_functions
        def body_name():
          body
          return ()
        def test_name():
          return test
        ag__.while_stmt(
            test_name,
            body_name,
            state_getter_name,
            state_setter_name,
            (),
            (),
            (composite_symbol_names,),
            opts)
      """
      node = templates.replace(
          template,
          test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols),
          test=node.test,
          body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
          body=node.body,
          state_functions=state_functions,
          state_getter_name=state_getter_name,
          state_setter_name=state_setter_name,
          composite_symbol_names=composite_symbol_names,
          opts=opts)

    undefined_assigns = self._create_undefined_assigns(possibly_undefs)
    return undefined_assigns + node

  def visit_For(self, node):
    """Rewrites a `for` loop into a body function plus `ag__.for_stmt`.

    If the node carries an `EXTRA_LOOP_TEST` annotation, an additional test
    function is generated and passed to `ag__.for_stmt`; otherwise `None` is
    passed in its place.
    """
    node = self.generic_visit(node)

    (basic_loop_vars, composite_loop_vars,
     reserved_symbols, possibly_undefs) = self._get_loop_vars(
         node, (anno.getanno(node, annos.NodeAnno.BODY_SCOPE).modified
                | anno.getanno(node, annos.NodeAnno.ITERATE_SCOPE).modified))
    loop_vars, loop_vars_ast_tuple = self._loop_var_constructs(
        basic_loop_vars)
    body_name = self.ctx.namer.new_symbol('loop_body', reserved_symbols)

    state_getter_name = self.ctx.namer.new_symbol('get_state', reserved_symbols)
    state_setter_name = self.ctx.namer.new_symbol('set_state', reserved_symbols)
    state_functions = self._create_state_functions(
        composite_loop_vars, state_getter_name, state_setter_name)

    if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
      extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)
      extra_test_name = self.ctx.namer.new_symbol(
          'extra_test', reserved_symbols)
      template = """
        def extra_test_name(loop_vars):
          return extra_test_expr
      """
      extra_test_function = templates.replace(
          template,
          extra_test_name=extra_test_name,
          loop_vars=loop_vars,
          extra_test_expr=extra_test)
    else:
      extra_test_name = parser.parse_expression('None')
      extra_test_function = []

    # Workaround for PEP-3113
    # iterates_var holds a single variable with the iterates, which may be a
    # tuple.
    iterates_var_name = self.ctx.namer.new_symbol(
        'iterates', reserved_symbols)
    template = """
      iterates = iterates_var_name
    """
    iterate_expansion = templates.replace(
        template,
        iterates=node.target,
        iterates_var_name=iterates_var_name)

    undefined_assigns = self._create_undefined_assigns(possibly_undefs)

    basic_symbol_names = tuple(
        gast.Constant(str(symbol), kind=None) for symbol in basic_loop_vars)
    composite_symbol_names = tuple(
        gast.Constant(str(symbol), kind=None) for symbol in composite_loop_vars)

    opts = self._create_loop_options(node)

    # TODO(mdan): Use a single template.
    # If the body and test functions took a single tuple for loop_vars, instead
    # of *loop_vars, then a single template could be used.
    if loop_vars:
      template = """
        undefined_assigns
        state_functions
        def body_name(iterates_var_name, loop_vars):
          iterate_expansion
          body
          return loop_vars,
        extra_test_function
        loop_vars_ast_tuple = ag__.for_stmt(
            iter_,
            extra_test_name,
            body_name,
            state_getter_name,
            state_setter_name,
            (loop_vars,),
            (basic_symbol_names,),
            (composite_symbol_names,),
            opts)
      """
      return templates.replace(
          template,
          undefined_assigns=undefined_assigns,
          loop_vars=loop_vars,
          loop_vars_ast_tuple=loop_vars_ast_tuple,
          iter_=node.iter,
          iterate_expansion=iterate_expansion,
          iterates_var_name=iterates_var_name,
          extra_test_name=extra_test_name,
          extra_test_function=extra_test_function,
          body_name=body_name,
          body=node.body,
          state_functions=state_functions,
          state_getter_name=state_getter_name,
          state_setter_name=state_setter_name,
          basic_symbol_names=basic_symbol_names,
          composite_symbol_names=composite_symbol_names,
          opts=opts)
    else:
      template = """
        undefined_assigns
        state_functions
        def body_name(iterates_var_name):
          iterate_expansion
          body
          return ()
        extra_test_function
        ag__.for_stmt(
            iter_,
            extra_test_name,
            body_name,
            state_getter_name,
            state_setter_name,
            (),
            (),
            (composite_symbol_names,),
            opts)
      """
      return templates.replace(
          template,
          undefined_assigns=undefined_assigns,
          iter_=node.iter,
          iterate_expansion=iterate_expansion,
          iterates_var_name=iterates_var_name,
          extra_test_name=extra_test_name,
          extra_test_function=extra_test_function,
          body_name=body_name,
          body=node.body,
          state_functions=state_functions,
          state_getter_name=state_getter_name,
          state_setter_name=state_setter_name,
          composite_symbol_names=composite_symbol_names,
          opts=opts)


class AnnotatedDef(reaching_definitions.Definition):
  """A reaching-definitions `Definition` that also carries a `directives` dict."""

  def __init__(self):
    super(AnnotatedDef, self).__init__()
    self.directives = {}


def transform(node, ctx):
  """Runs the static analyses `ControlFlowTransformer` needs, then applies it.

  Args:
    node: the AST to convert.
    ctx: the conversion context passed to the analyses and the transformer.

  Returns:
    The transformed AST.
  """
  graphs = cfg.build(node)
  node = qual_names.resolve(node)
  node = activity.resolve(node, ctx, None)
  node = reaching_definitions.resolve(node, ctx, graphs)
  node = reaching_fndefs.resolve(node, ctx, graphs)
  node = liveness.resolve(node, ctx, graphs)

  node = ControlFlowTransformer(ctx).visit(node)
  return node