# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Handles control flow statements: while, for, if."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gast

from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.lang import directives
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import origin_info
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import annos
from tensorflow.python.autograph.pyct.static_analysis import liveness
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs


class _Function(object):
  """Transformer state for the function (or lambda) being visited."""

  # Scope annotation of the innermost enclosing function or lambda; set by
  # ControlFlowTransformer.visit_Lambda / visit_FunctionDef.
  scope = None


class ControlFlowTransformer(converter.Base):
  """Transforms control flow structures like loops and conditionals."""

  def visit_Lambda(self, node):
    # The state must stay active while the children are visited, so that
    # nested control flow sees this lambda's scope.
    with self.state[_Function] as fn:
      fn.scope = anno.getanno(node, anno.Static.SCOPE)
      return self.generic_visit(node)

  def visit_FunctionDef(self, node):
    # Same as visit_Lambda, but functions record their body scope.
    with self.state[_Function] as fn:
      fn.scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
      return self.generic_visit(node)

  def _create_nonlocal_declarations(self, vars_):
    """Creates `global` / `nonlocal` statements for the given symbols.

    Args:
      vars_: Iterable of symbols modified by a control flow block.

    Returns:
      A list with at most one gast.Global and at most one gast.Nonlocal node.
      Symbols in the enclosing function's `globals` are declared `global`.
      The remaining non-composite symbols are declared `nonlocal`. Names
      are sorted so the generated code is deterministic across runs.
    """
    vars_ = set(vars_)
    results = []
    global_vars = self.state[_Function].scope.globals & vars_

    if global_vars:
      results.append(gast.Global(sorted(str(v) for v in global_vars)))

    nonlocal_vars = [
        v for v in vars_ if not v.is_composite() and v not in global_vars]
    if nonlocal_vars:
      results.append(gast.Nonlocal(sorted(str(v) for v in nonlocal_vars)))

    return results

  def _create_state_functions(
      self, block_vars, nonlocal_declarations, getter_name, setter_name):
    """Creates the getter and setter functions for a block's state.

    The getter returns the current values of `block_vars` as a tuple, and the
    setter unpacks a tuple back into them. Non-simple symbols are wrapped in
    an `ag__.ldu(lambda: ..., name)` call in the getter.

    Args:
      block_vars: Sequence of symbols making up the block's state.
      nonlocal_declarations: Declarations emitted at the top of the setter.
      getter_name: Name of the generated getter function.
      setter_name: Name of the generated setter function.

    Returns:
      The nodes produced by templates.replace: the getter and setter defs.
    """
    if not block_vars:
      template = """
        def getter_name():
          return ()
        def setter_name(block_vars):
          pass
      """
      return templates.replace(
          template, getter_name=getter_name, setter_name=setter_name)

    guarded_block_vars = []
    for v in block_vars:
      if v.is_simple():
        guarded_block_vars.append(v)
      else:
        guarded_block_vars.append(
            templates.replace_as_expression(
                'ag__.ldu(lambda: var_, name)',
                var_=v,
                name=gast.Constant(str(v), kind=None)))

    template = """
      def getter_name():
        return guarded_state_vars,
      def setter_name(vars_):
        nonlocal_declarations
        state_vars, = vars_
    """
    return templates.replace(
        template,
        nonlocal_declarations=nonlocal_declarations,
        getter_name=getter_name,
        guarded_state_vars=guarded_block_vars,
        setter_name=setter_name,
        state_vars=tuple(block_vars))

  def _create_loop_options(self, node):
    """Builds the loop options dict passed to the loop operators.

    Args:
      node: A loop node, possibly annotated with directives.

    Returns:
      A gast.Dict mapping option names (string constants) to the option values
      recorded for the `set_loop_options` directive. The dict is empty when the
      loop has no such directive, or when the directive recorded no options.
    """
    if not anno.hasanno(node, anno.Basic.DIRECTIVES):
      return gast.Dict([], [])

    loop_directives = anno.getanno(node, anno.Basic.DIRECTIVES)
    if directives.set_loop_options not in loop_directives:
      return gast.Dict([], [])

    opts_dict = loop_directives[directives.set_loop_options]
    # Build both lists directly instead of via `zip(*opts_dict.items())`.
    # The zip form fails to unpack when opts_dict is empty. Both fields are
    # lists because ast and gast don't play well with tuples, and visit_For
    # appends to them.
    keys = [gast.Constant(k, kind=None) for k in opts_dict]
    values = list(opts_dict.values())
    return gast.Dict(keys, values)

  def _create_undefined_assigns(self, undefined_symbols):
    """Creates an `ag__.Undefined(...)` assignment for each symbol.

    Each symbol is assigned `ag__.Undefined(<name>)`, where <name> is the
    string returned by the symbol's `ssf()` method.
    """
    template = '''
      var = ag__.Undefined(symbol_name)
    '''
    assignments = []
    for s in undefined_symbols:
      assignments += templates.replace(
          template,
          var=s,
          symbol_name=gast.Constant(s.ssf(), kind=None))
    return assignments

  def _get_block_basic_vars(self, modified, live_in, live_out):
    """Selects the simple (non-composite) modified symbols that form state.

    Args:
      modified: Symbols modified inside the block.
      live_in: Symbols live into the block.
      live_out: Symbols live out of the block.

    Returns:
      A frozenset of the modified simple symbols that are live in, live out,
      or nonlocal in the enclosing function.
    """
    nonlocals = self.state[_Function].scope.nonlocals
    basic_scope_vars = []
    for s in modified:
      if s.is_composite():
        # TODO(mdan): Raise an error when this happens for a TF scope.
        continue
      # Variables not live into or out of the scope are considered local to the
      # scope.
      if s in live_in or s in live_out or s in nonlocals:
        basic_scope_vars.append(s)
    return frozenset(basic_scope_vars)

  def _get_block_composite_vars(self, modified, live_in):
    """Selects the composite modified symbols (e.g. `self.x`) that form state.

    Args:
      modified: Symbols modified inside the block.
      live_in: Symbols live into the block.

    Returns:
      A frozenset of the modified composite symbols whose symbol (non-literal)
      support set is entirely live into the block.
    """
    composite_scope_vars = []
    for s in modified:
      if not s.is_composite():
        continue
      # Mutations made to objects created inside the scope will appear as writes
      # to composite symbols. Because these mutations appear as modifications
      # made to composite symbols, we check whether the composite's parent is
      # actually live into the scope.
      # Example:
      #   while cond:
      #     x = Foo()
      #     x.foo = 2 * x.foo  # x.foo is live into the scope, but x is not.
      #
      # Note that some parents might not be symbols - for example, in x['foo'],
      # 'foo' is a parent, but it's a literal, not a symbol. We don't check the
      # liveness of literals.
      support_set_symbols = tuple(
          sss for sss in s.support_set if sss.is_symbol())
      if not all(sss in live_in for sss in support_set_symbols):
        continue
      composite_scope_vars.append(s)
    return frozenset(composite_scope_vars)

  def _get_block_vars(self, node, modified):
    """Determines the variables affected inside a control flow statement.

    Args:
      node: The control flow node, annotated with reaching definitions and
        liveness information.
      modified: Symbols modified inside the statement.

    Returns:
      A tuple (scope_vars, undefined, nouts):
        scope_vars: list of state symbols, outputs first, each group sorted.
        undefined: tuple of simple modified symbols not defined on entry and
          not global or nonlocal.
        nouts: number of leading entries in scope_vars that are outputs.
    """
    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
    live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN)
    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
    fn_scope = self.state[_Function].scope

    basic_scope_vars = self._get_block_basic_vars(
        modified,
        live_in,
        live_out)
    composite_scope_vars = self._get_block_composite_vars(modified, live_in)
    scope_vars = tuple(basic_scope_vars | composite_scope_vars)

    # Variables that are modified inside the scope, but not defined
    # before entering it. Only simple variables must be defined. The
    # composite ones will be implicitly checked at runtime.
    possibly_undefined = (
        modified - defined_in - fn_scope.globals - fn_scope.nonlocals)
    undefined = tuple(v for v in possibly_undefined if not v.is_composite())

    # Variables that are modified inside the scope and depend on values
    # outside it (live in), but whose values are not used after it (not live
    # out). Parenthesized for clarity; same as the unparenthesized form since
    # `-` binds tighter than `&`.
    input_only = (basic_scope_vars & live_in) - live_out

    # Place the outputs first, then sort lexicographically.
    scope_vars = sorted(scope_vars, key=lambda v: (v in input_only, v))
    nouts = len(scope_vars) - len(input_only)

    return scope_vars, undefined, nouts

  def visit_If(self, node):
    """Replaces an `if` statement with a call to `ag__.if_stmt`.

    The branches become local functions. State is shared with them through
    `nonlocal`/`global` declarations and generated getter/setter functions.
    """
    node = self.generic_visit(node)
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)

    cond_vars, undefined, nouts = self._get_block_vars(
        node, body_scope.bound | orelse_scope.bound)

    undefined_assigns = self._create_undefined_assigns(undefined)

    nonlocal_declarations = self._create_nonlocal_declarations(cond_vars)

    # Names referenced by either branch are handed to the namer as reserved.
    reserved = body_scope.referenced | orelse_scope.referenced
    state_getter_name = self.ctx.namer.new_symbol('get_state', reserved)
    state_setter_name = self.ctx.namer.new_symbol('set_state', reserved)
    state_functions = self._create_state_functions(
        cond_vars, nonlocal_declarations, state_getter_name, state_setter_name)

    orelse_body = node.orelse
    if not orelse_body:
      # A function body can't be empty; use `pass` for a missing else branch.
      orelse_body = [gast.Pass()]

    template = """
      state_functions
      def body_name():
        nonlocal_declarations
        body
      def orelse_name():
        nonlocal_declarations
        orelse
      undefined_assigns
      ag__.if_stmt(
        test,
        body_name,
        orelse_name,
        state_getter_name,
        state_setter_name,
        (symbol_names,),
        nouts)
    """
    new_nodes = templates.replace(
        template,
        body=node.body,
        body_name=self.ctx.namer.new_symbol('if_body', reserved),
        orelse=orelse_body,
        orelse_name=self.ctx.namer.new_symbol('else_body', reserved),
        nonlocal_declarations=nonlocal_declarations,
        nouts=gast.Constant(nouts, kind=None),
        state_functions=state_functions,
        state_getter_name=state_getter_name,
        state_setter_name=state_setter_name,
        symbol_names=tuple(gast.Constant(str(s), kind=None) for s in cond_vars),
        test=node.test,
        undefined_assigns=undefined_assigns)
    origin_info.copy_origin(node, new_nodes[-1])
    return new_nodes

  def visit_While(self, node):
    """Replaces a `while` loop with a call to `ag__.while_stmt`."""
    node = self.generic_visit(node)
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)

    loop_vars, undefined, _ = self._get_block_vars(node, body_scope.bound)

    undefined_assigns = self._create_undefined_assigns(undefined)

    nonlocal_declarations = self._create_nonlocal_declarations(loop_vars)

    reserved = body_scope.referenced
    state_getter_name = self.ctx.namer.new_symbol('get_state', reserved)
    state_setter_name = self.ctx.namer.new_symbol('set_state', reserved)
    state_functions = self._create_state_functions(
        loop_vars, nonlocal_declarations, state_getter_name, state_setter_name)

    opts = self._create_loop_options(node)

    template = """
      state_functions
      def body_name():
        nonlocal_declarations
        body
      def test_name():
        return test
      undefined_assigns
      ag__.while_stmt(
          test_name,
          body_name,
          state_getter_name,
          state_setter_name,
          (symbol_names,),
          opts)
    """
    new_nodes = templates.replace(
        template,
        body=node.body,
        body_name=self.ctx.namer.new_symbol('loop_body', reserved),
        nonlocal_declarations=nonlocal_declarations,
        opts=opts,
        state_functions=state_functions,
        state_getter_name=state_getter_name,
        state_setter_name=state_setter_name,
        symbol_names=tuple(gast.Constant(str(s), kind=None) for s in loop_vars),
        test=node.test,
        test_name=self.ctx.namer.new_symbol('loop_test', reserved),
        undefined_assigns=undefined_assigns)
    origin_info.copy_origin(node, new_nodes[-1])
    return new_nodes

  def visit_For(self, node):
    """Replaces a `for` loop with a call to `ag__.for_stmt`.

    The loop target is unparsed into the `iterate_names` loop option. An
    optional extra loop test (the EXTRA_LOOP_TEST annotation) becomes a local
    function.
    """
    node = self.generic_visit(node)
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    iter_scope = anno.getanno(node, annos.NodeAnno.ITERATE_SCOPE)

    loop_vars, undefined, _ = self._get_block_vars(
        node, body_scope.bound | iter_scope.bound)

    undefined_assigns = self._create_undefined_assigns(undefined)

    nonlocal_declarations = self._create_nonlocal_declarations(loop_vars)

    reserved = body_scope.referenced | iter_scope.referenced
    state_getter_name = self.ctx.namer.new_symbol('get_state', reserved)
    state_setter_name = self.ctx.namer.new_symbol('set_state', reserved)
    state_functions = self._create_state_functions(
        loop_vars, nonlocal_declarations, state_getter_name, state_setter_name)

    opts = self._create_loop_options(node)
    opts.keys.append(gast.Constant('iterate_names', kind=None))
    opts.values.append(gast.Constant(
        parser.unparse(node.target, include_encoding_marker=False), kind=None))

    if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
      extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)
      extra_test_name = self.ctx.namer.new_symbol(
          'extra_test', reserved)
      template = """
        def extra_test_name():
          nonlocal_declarations
          return extra_test_expr
      """
      extra_test_function = templates.replace(
          template,
          extra_test_expr=extra_test,
          extra_test_name=extra_test_name,
          nonlocal_declarations=nonlocal_declarations)
    else:
      extra_test_name = parser.parse_expression('None')
      extra_test_function = []

    # iterate_arg_name holds a single arg with the iterates, which may be a
    # tuple.
    iterate_arg_name = self.ctx.namer.new_symbol('itr', reserved)
    template = """
      iterates = iterate_arg_name
    """
    iterate_expansion = templates.replace(
        template, iterate_arg_name=iterate_arg_name, iterates=node.target)
    origin_info.copy_origin(node, iterate_expansion)

    template = """
      state_functions
      def body_name(iterate_arg_name):
        nonlocal_declarations
        iterate_expansion
        body
      extra_test_function
      undefined_assigns
      ag__.for_stmt(
          iterated,
          extra_test_name,
          body_name,
          state_getter_name,
          state_setter_name,
          (symbol_names,),
          opts)
    """
    new_nodes = templates.replace(
        template,
        body=node.body,
        body_name=self.ctx.namer.new_symbol('loop_body', reserved),
        extra_test_function=extra_test_function,
        extra_test_name=extra_test_name,
        iterate_arg_name=iterate_arg_name,
        iterate_expansion=iterate_expansion,
        iterated=node.iter,
        nonlocal_declarations=nonlocal_declarations,
        opts=opts,
        symbol_names=tuple(gast.Constant(str(s), kind=None) for s in loop_vars),
        state_functions=state_functions,
        state_getter_name=state_getter_name,
        state_setter_name=state_setter_name,
        undefined_assigns=undefined_assigns)
    origin_info.copy_origin(node, new_nodes[-1])
    return new_nodes


class AnnotatedDef(reaching_definitions.Definition):
  """A reaching-definitions Definition with an extra `directives` dict."""

  def __init__(self):
    super(AnnotatedDef, self).__init__()
    self.directives = {}


def transform(node, ctx):
  """Runs the required static analyses, then ControlFlowTransformer.

  Args:
    node: The AST to convert.
    ctx: The converter context, passed to the analyses and the transformer.

  Returns:
    The transformed AST.
  """
  graphs = cfg.build(node)
  node = qual_names.resolve(node)
  node = activity.resolve(node, ctx, None)
  node = reaching_definitions.resolve(node, ctx, graphs)
  node = reaching_fndefs.resolve(node, ctx, graphs)
  node = liveness.resolve(node, ctx, graphs)

  node = ControlFlowTransformer(ctx).visit(node)
  return node