# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Handles control flow statements: while, for, if."""

import gast

from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.lang import directives
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import origin_info
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import annos
from tensorflow.python.autograph.pyct.static_analysis import liveness
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs


class _Function(object):
  """Transformer state for the function or lambda currently being visited."""

  # Static scope of that function/lambda. It is set by
  # ControlFlowTransformer.visit_Lambda and visit_FunctionDef.
  scope = None


class ControlFlowTransformer(converter.Base):
  """Transforms control flow structures like loops and conditionals.

  Each `if`, `while` and `for` statement is replaced by a call to
  `ag__.if_stmt`, `ag__.while_stmt` or `ag__.for_stmt` respectively. The
  statement's body (and, depending on the statement, its else branch, loop
  test or extra loop test) is moved into locally defined functions. The
  variables modified by the statement are exchanged through a generated pair
  of state getter/setter functions.
  """

  def visit_Lambda(self, node):
    """Records the lambda's static scope, then visits its children."""
    with self.state[_Function] as fn:
      fn.scope = anno.getanno(node, anno.Static.SCOPE)
      return self.generic_visit(node)

  def visit_FunctionDef(self, node):
    """Records the function's body scope, then visits its children."""
    with self.state[_Function] as fn:
      fn.scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
      return self.generic_visit(node)

  def _create_nonlocal_declarations(self, vars_):
    """Builds the `global` / `nonlocal` statements for the given symbols.

    Symbols that are globals of the enclosing function get a `global`
    statement. The remaining non-composite symbols get a `nonlocal` statement.
    Composite symbols (e.g. `a.b`) cannot be declared and are skipped. Names
    are sorted so the generated code is deterministic across runs.

    Args:
      vars_: iterable of qualified-name symbols.

    Returns:
      A list with zero, one or two `gast.Global` / `gast.Nonlocal` nodes.
    """
    vars_ = set(vars_)
    results = []
    global_vars = self.state[_Function].scope.globals & vars_

    if global_vars:
      results.append(gast.Global(sorted(str(v) for v in global_vars)))

    nonlocal_vars = [
        v for v in vars_ if not v.is_composite() and v not in global_vars]
    if nonlocal_vars:
      results.append(gast.Nonlocal(sorted(str(v) for v in nonlocal_vars)))

    return results

  def _create_state_functions(
      self, block_vars, nonlocal_declarations, getter_name, setter_name):
    """Generates the getter and setter for a control flow statement's state.

    Args:
      block_vars: the qualified-name symbols carried through the statement.
      nonlocal_declarations: `global`/`nonlocal` statements placed in the
        setter so that its assignments affect the enclosing scope.
      getter_name: name of the generated getter function.
      setter_name: name of the generated setter function.

    Returns:
      A list of AST nodes defining both functions. When `block_vars` is
      empty, the getter returns `()` and the setter does nothing. Otherwise
      the getter returns the symbols' current values, and the setter assigns
      the values it receives back to the symbols. Symbols that are not simple
      names are read through `ag__.ldu(lambda: ..., name)`. This presumably
      tolerates undefined composites; confirm against `ag__.ldu`.
    """
    if not block_vars:
      template = """
        def getter_name():
          return ()
        def setter_name(block_vars):
          pass
      """
      return templates.replace(
          template, getter_name=getter_name, setter_name=setter_name)

    guarded_block_vars = []
    for v in block_vars:
      if v.is_simple():
        guarded_block_vars.append(v)
      else:
        guarded_block_vars.append(
            templates.replace_as_expression(
                'ag__.ldu(lambda: var_, name)',
                var_=v,
                name=gast.Constant(str(v), kind=None)))

    template = """
      def getter_name():
        return guarded_state_vars,
      def setter_name(vars_):
        nonlocal_declarations
        state_vars, = vars_
    """
    return templates.replace(
        template,
        nonlocal_declarations=nonlocal_declarations,
        getter_name=getter_name,
        guarded_state_vars=guarded_block_vars,
        setter_name=setter_name,
        state_vars=tuple(block_vars))

  def _create_loop_options(self, node):
    """Builds a dict literal with the loop's `set_loop_options` arguments.

    Args:
      node: the loop AST node, possibly annotated with directives.

    Returns:
      A `gast.Dict` mapping option names (string constants) to the AST
      values given in the directive. The dict is empty when the loop has no
      directives, no `set_loop_options` directive, or a directive without
      options.
    """
    if not anno.hasanno(node, anno.Basic.DIRECTIVES):
      return gast.Dict([], [])

    loop_directives = anno.getanno(node, anno.Basic.DIRECTIVES)
    if directives.set_loop_options not in loop_directives:
      return gast.Dict([], [])

    opts_dict = loop_directives[directives.set_loop_options]
    # Keys and values are read from the same dict, so their orders agree.
    # Unlike `zip(*opts_dict.items())`, this also works for an empty dict.
    keys = [gast.Constant(k, kind=None) for k in opts_dict]
    values = list(opts_dict.values())  # ast and gast don't play well with tuples.
    return gast.Dict(keys, values)

  def _create_undefined_assigns(self, undefined_symbols):
    """Generates `var = ag__.Undefined(name)` for each symbol.

    Args:
      undefined_symbols: qualified-name symbols to initialize. The symbol's
        `ssf()` string is passed as the name.

    Returns:
      A flat list of assignment AST nodes, in input order.
    """
    template = '''
      var = ag__.Undefined(symbol_name)
    '''
    assignments = []
    for s in undefined_symbols:
      assignments += templates.replace(
          template,
          var=s,
          symbol_name=gast.Constant(s.ssf(), kind=None))
    return assignments

  def _get_block_basic_vars(self, modified, live_in, live_out):
    """Selects the modified non-composite symbols carried as block state.

    A symbol qualifies when it is live into or out of the block, or is
    declared nonlocal in the enclosing function. Other modified symbols are
    considered local to the block.

    Returns:
      A frozenset of symbols.
    """
    nonlocals = self.state[_Function].scope.nonlocals
    basic_scope_vars = []
    for s in modified:
      if s.is_composite():
        # TODO(mdan): Raise an error when this happens for a TF scope.
        continue
      if s in live_in or s in live_out or s in nonlocals:
        basic_scope_vars.append(s)
    return frozenset(basic_scope_vars)

  def _get_block_composite_vars(self, modified, live_in):
    """Selects the modified composite symbols (e.g. `self.x`) to carry.

    Returns:
      A frozenset of composite symbols whose symbol parents are all live into
      the block.
    """
    composite_scope_vars = []
    for s in modified:
      if not s.is_composite():
        continue
      # Mutations made to objects created inside the scope will appear as writes
      # to composite symbols. Because these mutations appear as modifications
      # made to composite symbols, we check whether the composite's parent is
      # actually live into the scope.
      # Example:
      #   while cond:
      #     x = Foo()
      #     x.foo = 2 * x.foo  # x.foo is live into the scope, but x is not.
      #
      # Note that some parents might not be symbols - for example, in x['foo'],
      # 'foo' is a parent, but it's a literal, not a symbol. We don't check the
      # liveness of literals.
      support_set_symbols = tuple(
          sss for sss in s.support_set if sss.is_symbol())
      if not all(sss in live_in for sss in support_set_symbols):
        continue
      composite_scope_vars.append(s)
    return frozenset(composite_scope_vars)

  def _get_block_vars(self, node, modified):
    """Determines the variables affected inside a control flow statement.

    Args:
      node: the control flow AST node, annotated with static analysis results
        (defined vars in, live vars in/out).
      modified: the symbols bound inside the statement.

    Returns:
      A tuple `(scope_vars, undefined, nouts)`:
        scope_vars: list of symbols carried as state. Outputs come first, then
          input-only symbols, each group sorted.
        undefined: tuple of non-composite symbols modified inside the block
          but not defined before it (excluding globals and nonlocals).
        nouts: number of leading entries in `scope_vars` that are outputs.
    """
    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
    live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN)
    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
    fn_scope = self.state[_Function].scope

    basic_scope_vars = self._get_block_basic_vars(
        modified,
        live_in,
        live_out)
    composite_scope_vars = self._get_block_composite_vars(modified, live_in)
    scope_vars = tuple(basic_scope_vars | composite_scope_vars)

    # Variables that are modified inside the scope, but not defined
    # before entering it. Only simple variables must be defined. The
    # composite ones will be implicitly checked at runtime.
    possibly_undefined = (
        modified - defined_in - fn_scope.globals - fn_scope.nonlocals)
    undefined = tuple(v for v in possibly_undefined if not v.is_composite())

    # Variables that are modified inside the scope and depend on values outside
    # it, but are not used after it.
    input_only = basic_scope_vars & (live_in - live_out)

    # Place the outputs first, then sort lexicographically.
    scope_vars = sorted(scope_vars, key=lambda v: (v in input_only, v))
    nouts = len(scope_vars) - len(input_only)

    return scope_vars, undefined, nouts

  def visit_If(self, node):
    """Replaces an `if` statement with an `ag__.if_stmt` call.

    The body and else branch become local functions. A `pass` is used for a
    missing else branch. Symbols bound in either branch are exchanged through
    generated state functions.
    """
    node = self.generic_visit(node)
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)

    cond_vars, undefined, nouts = self._get_block_vars(
        node, body_scope.bound | orelse_scope.bound)

    undefined_assigns = self._create_undefined_assigns(undefined)

    nonlocal_declarations = self._create_nonlocal_declarations(cond_vars)

    reserved = body_scope.referenced | orelse_scope.referenced
    state_getter_name = self.ctx.namer.new_symbol('get_state', reserved)
    state_setter_name = self.ctx.namer.new_symbol('set_state', reserved)
    state_functions = self._create_state_functions(
        cond_vars, nonlocal_declarations, state_getter_name, state_setter_name)

    orelse_body = node.orelse
    if not orelse_body:
      orelse_body = [gast.Pass()]

    template = """
      state_functions
      def body_name():
        nonlocal_declarations
        body
      def orelse_name():
        nonlocal_declarations
        orelse
      undefined_assigns
      ag__.if_stmt(
        test,
        body_name,
        orelse_name,
        state_getter_name,
        state_setter_name,
        (symbol_names,),
        nouts)
    """
    new_nodes = templates.replace(
        template,
        body=node.body,
        body_name=self.ctx.namer.new_symbol('if_body', reserved),
        orelse=orelse_body,
        orelse_name=self.ctx.namer.new_symbol('else_body', reserved),
        nonlocal_declarations=nonlocal_declarations,
        nouts=gast.Constant(nouts, kind=None),
        state_functions=state_functions,
        state_getter_name=state_getter_name,
        state_setter_name=state_setter_name,
        symbol_names=tuple(gast.Constant(str(s), kind=None) for s in cond_vars),
        test=node.test,
        undefined_assigns=undefined_assigns)
    origin_info.copy_origin(node, new_nodes[-1])
    return new_nodes

  def visit_While(self, node):
    """Replaces a `while` loop with an `ag__.while_stmt` call.

    The loop body and test become local functions, and loop options from a
    `set_loop_options` directive are forwarded.

    NOTE(review): `node.orelse` is not read here. Presumably while/else is
    lowered by an earlier pass; confirm.
    """
    node = self.generic_visit(node)
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)

    loop_vars, undefined, _ = self._get_block_vars(node, body_scope.bound)

    undefined_assigns = self._create_undefined_assigns(undefined)

    nonlocal_declarations = self._create_nonlocal_declarations(loop_vars)

    reserved = body_scope.referenced
    state_getter_name = self.ctx.namer.new_symbol('get_state', reserved)
    state_setter_name = self.ctx.namer.new_symbol('set_state', reserved)
    state_functions = self._create_state_functions(
        loop_vars, nonlocal_declarations, state_getter_name, state_setter_name)

    opts = self._create_loop_options(node)

    template = """
      state_functions
      def body_name():
        nonlocal_declarations
        body
      def test_name():
        return test
      undefined_assigns
      ag__.while_stmt(
          test_name,
          body_name,
          state_getter_name,
          state_setter_name,
          (symbol_names,),
          opts)
    """
    new_nodes = templates.replace(
        template,
        body=node.body,
        body_name=self.ctx.namer.new_symbol('loop_body', reserved),
        nonlocal_declarations=nonlocal_declarations,
        opts=opts,
        state_functions=state_functions,
        state_getter_name=state_getter_name,
        state_setter_name=state_setter_name,
        symbol_names=tuple(gast.Constant(str(s), kind=None) for s in loop_vars),
        test=node.test,
        test_name=self.ctx.namer.new_symbol('loop_test', reserved),
        undefined_assigns=undefined_assigns)
    origin_info.copy_origin(node, new_nodes[-1])
    return new_nodes

  def visit_For(self, node):
    """Replaces a `for` loop with an `ag__.for_stmt` call.

    The loop body becomes a local function taking a single argument that
    holds the iterates. That argument is unpacked into the original loop
    target at the start of the body. The unparsed target is also passed as
    the `iterate_names` loop option. If the node carries an
    `EXTRA_LOOP_TEST` annotation, that expression becomes a local test
    function; otherwise `None` is passed in its place.

    NOTE(review): `node.orelse` is not read here. Presumably for/else is
    lowered by an earlier pass; confirm.
    """
    node = self.generic_visit(node)
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    iter_scope = anno.getanno(node, annos.NodeAnno.ITERATE_SCOPE)

    loop_vars, undefined, _ = self._get_block_vars(
        node, body_scope.bound | iter_scope.bound)

    undefined_assigns = self._create_undefined_assigns(undefined)

    nonlocal_declarations = self._create_nonlocal_declarations(loop_vars)

    reserved = body_scope.referenced | iter_scope.referenced
    state_getter_name = self.ctx.namer.new_symbol('get_state', reserved)
    state_setter_name = self.ctx.namer.new_symbol('set_state', reserved)
    state_functions = self._create_state_functions(
        loop_vars, nonlocal_declarations, state_getter_name, state_setter_name)

    opts = self._create_loop_options(node)
    opts.keys.append(gast.Constant('iterate_names', kind=None))
    opts.values.append(gast.Constant(
        parser.unparse(node.target, include_encoding_marker=False), kind=None))

    if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
      extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)
      extra_test_name = self.ctx.namer.new_symbol(
          'extra_test', reserved)
      template = """
        def extra_test_name():
          nonlocal_declarations
          return extra_test_expr
      """
      extra_test_function = templates.replace(
          template,
          extra_test_expr=extra_test,
          extra_test_name=extra_test_name,
          nonlocal_declarations=nonlocal_declarations)
    else:
      extra_test_name = parser.parse_expression('None')
      extra_test_function = []

    # iterate_arg_name holds a single arg with the iterates, which may be a
    # tuple.
    iterate_arg_name = self.ctx.namer.new_symbol('itr', reserved)
    template = """
      iterates = iterate_arg_name
    """
    iterate_expansion = templates.replace(
        template, iterate_arg_name=iterate_arg_name, iterates=node.target)
    origin_info.copy_origin(node, iterate_expansion)

    template = """
      state_functions
      def body_name(iterate_arg_name):
        nonlocal_declarations
        iterate_expansion
        body
      extra_test_function
      undefined_assigns
      ag__.for_stmt(
          iterated,
          extra_test_name,
          body_name,
          state_getter_name,
          state_setter_name,
          (symbol_names,),
          opts)
    """
    new_nodes = templates.replace(
        template,
        body=node.body,
        body_name=self.ctx.namer.new_symbol('loop_body', reserved),
        extra_test_function=extra_test_function,
        extra_test_name=extra_test_name,
        iterate_arg_name=iterate_arg_name,
        iterate_expansion=iterate_expansion,
        iterated=node.iter,
        nonlocal_declarations=nonlocal_declarations,
        opts=opts,
        symbol_names=tuple(gast.Constant(str(s), kind=None) for s in loop_vars),
        state_functions=state_functions,
        state_getter_name=state_getter_name,
        state_setter_name=state_setter_name,
        undefined_assigns=undefined_assigns)
    origin_info.copy_origin(node, new_nodes[-1])
    return new_nodes


class AnnotatedDef(reaching_definitions.Definition):
  """A reaching-definitions `Definition` that also carries a directives dict."""

  def __init__(self):
    super(AnnotatedDef, self).__init__()
    self.directives = {}


def transform(node, ctx):
  """Runs the static analyses this converter needs, then converts the node.

  Builds the control flow graphs and resolves the following, in order:
  qualified names, activity, reaching definitions, reaching function
  definitions and liveness. It then applies ControlFlowTransformer.

  Args:
    node: the AST to convert.
    ctx: the converter context.

  Returns:
    The converted AST.
  """
  graphs = cfg.build(node)
  node = qual_names.resolve(node)
  node = activity.resolve(node, ctx, None)
  node = reaching_definitions.resolve(node, ctx, graphs)
  node = reaching_fndefs.resolve(node, ctx, graphs)
  node = liveness.resolve(node, ctx, graphs)

  node = ControlFlowTransformer(ctx).visit(node)
  return node