# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Canonicalizes functions with multiple returns to use just one."""

import gast

from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno


# Annotation keys written by ConditionalReturnRewriter.
# Set on an `if` node whose body definitely returns.
BODY_DEFINITELY_RETURNS = 'BODY_DEFINITELY_RETURNS'
# Set on an `if` node whose else branch definitely returns.
ORELSE_DEFINITELY_RETURNS = 'ORELSE_DEFINITELY_RETURNS'
# Set on a compound statement (currently only `with`) whose body definitely
# returns.
STMT_DEFINITELY_RETURNS = 'STMT_DEFINITELY_RETURNS'


class _RewriteBlock(object):
  """Per-statement-block state used by ConditionalReturnRewriter."""

  def __init__(self):
    # Becomes True once the statements visited so far in this block are known
    # to return on every path.
    self.definitely_returns = False


class ConditionalReturnRewriter(converter.Base):
  """Rewrites a pattern where it's unobvious that all paths return a value.

  This rewrite allows avoiding intermediate None return values.

  The following pattern:

      if cond:
        <block 1>
        return
      else:
        <block 2>
      <block 3>

  is converted to:

      if cond:
        <block 1>
        return
      else:
        <block 2>
        <block 3>

  and vice-versa (if the else returns, subsequent statements are moved under the
  if branch).
  """

  def visit_Return(self, node):
    # A return statement always returns, so the innermost block does too.
    self.state[_RewriteBlock].definitely_returns = True
    return node

  def _postprocess_statement(self, node):
    """Called after each statement of a block is visited.

    Args:
      node: the statement that was just visited.

    Returns:
      A tuple (node, destination). When destination is not None, the
      statements that follow `node` in the current block are placed into
      that list instead of after `node`.
    """
    # If the node definitely returns (e.g. it's a with statement with a
    # return statement in it), then the current block also definitely returns.
    if anno.getanno(node, STMT_DEFINITELY_RETURNS, default=False):
      self.state[_RewriteBlock].definitely_returns = True

    # The special case: collapse a typical conditional return pattern into
    # a single conditional with possibly returns on both branches. This
    # reduces the use of None return values, which don't work with TF
    # conditionals.
    if isinstance(node, gast.If):
      if anno.getanno(node, BODY_DEFINITELY_RETURNS, default=False):
        # The body returns: whatever follows can only run via the else branch.
        return node, node.orelse
      if anno.getanno(node, ORELSE_DEFINITELY_RETURNS, default=False):
        # The else branch returns: whatever follows can only run via the body.
        return node, node.body

    return node, None

  def _visit_statement_block(self, node, nodes):
    """Visits `nodes` as a fresh block.

    Args:
      node: the statement owning the block (unused; kept for symmetry with
        ReturnStatementsTransformer).
      nodes: the list of statements forming the block.

    Returns:
      A tuple (new_nodes, block_definitely_returns).
    """
    self.state[_RewriteBlock].enter()
    new_nodes = self.visit_block(nodes, after_visit=self._postprocess_statement)
    block_definitely_returns = self.state[_RewriteBlock].definitely_returns
    self.state[_RewriteBlock].exit()
    return new_nodes, block_definitely_returns

  def visit_While(self, node):
    # A loop body may run zero times, so a return inside it never makes the
    # enclosing block definitely return.
    node.test = self.visit(node.test)
    node.body, _ = self._visit_statement_block(node, node.body)
    node.orelse, _ = self._visit_statement_block(node, node.orelse)
    return node

  def visit_For(self, node):
    # Same reasoning as visit_While: the results are intentionally discarded.
    node.iter = self.visit(node.iter)
    node.target = self.visit(node.target)
    node.body, _ = self._visit_statement_block(node, node.body)
    node.orelse, _ = self._visit_statement_block(node, node.orelse)
    return node

  def visit_With(self, node):
    node.items = self.visit_block(node.items)
    node.body, definitely_returns = self._visit_statement_block(node, node.body)
    if definitely_returns:
      # Picked up by _postprocess_statement of the enclosing block.
      anno.setanno(node, STMT_DEFINITELY_RETURNS, True)
    return node

  def visit_Try(self, node):
    # We could decide whether a 'try' DEFINITELY_RETURNS based on its
    # components. It is not clear whether we want to do anything with this
    # given a 'try' is likely to throw an exception in some circumstances.
    node.body, _ = self._visit_statement_block(node, node.body)
    node.orelse, _ = self._visit_statement_block(node, node.orelse)
    node.finalbody, _ = self._visit_statement_block(node, node.finalbody)
    node.handlers = self.visit_block(node.handlers)
    return node

  def visit_ExceptHandler(self, node):
    # To determine whether `try` DEFINITELY_RETURNS we need to revisit this.
    node.body, _ = self._visit_statement_block(node, node.body)
    return node

  def visit_If(self, node):
    node.test = self.visit(node.test)

    node.body, body_definitely_returns = self._visit_statement_block(
        node, node.body)
    if body_definitely_returns:
      anno.setanno(node, BODY_DEFINITELY_RETURNS, True)

    node.orelse, orelse_definitely_returns = self._visit_statement_block(
        node, node.orelse)
    if orelse_definitely_returns:
      anno.setanno(node, ORELSE_DEFINITELY_RETURNS, True)

    # Both branches return, so the block containing this `if` returns too.
    if body_definitely_returns and orelse_definitely_returns:
      self.state[_RewriteBlock].definitely_returns = True

    return node

  def visit_FunctionDef(self, node):
    node.args = self.visit(node.args)
    node.body, _ = self._visit_statement_block(node, node.body)
    return node


class _Block(object):
  """Per-statement-block state used by ReturnStatementsTransformer."""

  def __init__(self):
    # True for the block that is the body of a function definition.
    self.is_function = False
    # True if a return statement was seen in this block or a nested block.
    self.return_used = False
    # Set by visit_Return; promoted to create_guard_now after the current
    # statement is post-processed.
    self.create_guard_next = False
    # When True, the next statement is wrapped in `if not do_return:`.
    self.create_guard_now = False

  def __repr__(self):
    return 'used: {}'.format(
        self.return_used)


class _Function(object):
  """Per-function state: names of the generated return-control variables."""

  def __init__(self):
    self.do_return_var_name = None
    self.retval_var_name = None

  def __repr__(self):
    return 'return control: {}, return value: {}'.format(
        self.do_return_var_name, self.retval_var_name)


class ReturnStatementsTransformer(converter.Base):
  """Lowers return statements into variables and conditionals.

  Specifically, the following pattern:

      <block 1>
      return val
      <block 2>

  is converted to:

      do_return = False
      retval = None

      <block 1>

      do_return = True
      retval = val

      if not do_return:
        <block 2>

      return retval

  The conversion adjusts loops as well:

      <block 1>
      while cond:
        <block 2>
        return val

  is converted to:

      <block 1>
      while not do_return and cond:
        <block 2>
        do_return = True
        retval = val

  The examples are simplified:

    * The assignments replacing a `return` are wrapped in a try/except that
      resets `do_return` and re-raises if evaluating `val` fails.
    * `for` loops receive the `not do_return` check through their
      EXTRA_LOOP_TEST annotation rather than a rewritten test.
    * With `allow_missing_return`, `retval` starts as
      `ag__.UndefinedReturnValue()` and the final statement is
      `return function_context.ret(retval, do_return)`. Otherwise the two
      variables are not initialized by this transformer.
  """

  def __init__(self, ctx, allow_missing_return):
    """Creates the transformer.

    Args:
      ctx: converter context; its namer creates the control variable names.
      allow_missing_return: whether to initialize the return variables and
        route the result through the function context. This requires the
        function body to be wrapped in a `with` by the functions converter.
    """
    super(ReturnStatementsTransformer, self).__init__(ctx)
    self.allow_missing_return = allow_missing_return

  def visit_Return(self, node):
    # Every block from here up to the enclosing function now contains a
    # return, and the statements following it must be guarded.
    for block in reversed(self.state[_Block].stack):
      block.return_used = True
      block.create_guard_next = True
      if block.is_function:
        break

    if node.value is not None:
      retval = node.value
    else:
      retval = parser.parse_expression('None')

    # Note: If `return <expr>` raises, then the return is aborted.
    # The try-catch below ensures the variables remain consistent in that case.
    template = """
      try:
        do_return_var_name = True
        retval_var_name = retval
      except:
        do_return_var_name = False
        raise
    """
    node = templates.replace(
        template,
        do_return_var_name=self.state[_Function].do_return_var_name,
        retval_var_name=self.state[_Function].retval_var_name,
        retval=retval)

    return node

  def _postprocess_statement(self, node):
    """Guards statements that follow a return in the current block.

    Args:
      node: the statement that was just visited.

    Returns:
      A tuple (node, destination). When destination is not None, the
      statements that follow are placed into it (here: the body of the
      `if not do_return:` guard).
    """
    state = self.state[_Block]
    if not state.return_used:
      return node, None

    destination = None
    if state.create_guard_now:
      template = """
        if not do_return_var_name:
          original_node
      """
      cond, = templates.replace(
          template,
          do_return_var_name=self.state[_Function].do_return_var_name,
          original_node=node)
      node, destination = cond, cond.body

    # A return seen while visiting `node` guards the statement after it.
    state.create_guard_now = state.create_guard_next
    state.create_guard_next = False

    return node, destination

  def _visit_statement_block(self, node, nodes):
    """Visits `nodes` as a fresh block and returns the new statement list."""
    self.state[_Block].enter()
    nodes = self.visit_block(nodes, after_visit=self._postprocess_statement)
    self.state[_Block].exit()
    return nodes

  def visit_While(self, node):
    node.test = self.visit(node.test)

    # Add the check for return to the loop condition.
    node.body = self._visit_statement_block(node, node.body)
    # The body block was exited, so this reads the enclosing block's flag,
    # which visit_Return also set if the body contained a return.
    if self.state[_Block].return_used:
      node.test = templates.replace_as_expression(
          'not control_var and test',
          test=node.test,
          control_var=self.state[_Function].do_return_var_name)

    node.orelse = self._visit_statement_block(node, node.orelse)
    return node

  def visit_For(self, node):
    node.iter = self.visit(node.iter)
    node.target = self.visit(node.target)

    # Add the check for return to the loop condition. `for` loops have no
    # test expression, so the check goes into the EXTRA_LOOP_TEST annotation,
    # combined with any existing extra test.
    node.body = self._visit_statement_block(node, node.body)
    if self.state[_Block].return_used:
      extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST, default=None)
      if extra_test is not None:
        extra_test = templates.replace_as_expression(
            'not control_var and extra_test',
            extra_test=extra_test,
            control_var=self.state[_Function].do_return_var_name)
      else:
        extra_test = templates.replace_as_expression(
            'not control_var',
            control_var=self.state[_Function].do_return_var_name)
      anno.setanno(node, anno.Basic.EXTRA_LOOP_TEST, extra_test)

    node.orelse = self._visit_statement_block(node, node.orelse)
    return node

  def visit_With(self, node):
    node.items = self.visit_block(node.items)
    node.body = self._visit_statement_block(node, node.body)
    return node

  def visit_Try(self, node):
    node.body = self._visit_statement_block(node, node.body)
    node.orelse = self._visit_statement_block(node, node.orelse)
    node.finalbody = self._visit_statement_block(node, node.finalbody)
    node.handlers = self.visit_block(node.handlers)
    return node

  def visit_ExceptHandler(self, node):
    node.body = self._visit_statement_block(node, node.body)
    return node

  def visit_If(self, node):
    node.test = self.visit(node.test)
    node.body = self._visit_statement_block(node, node.body)
    node.orelse = self._visit_statement_block(node, node.orelse)
    return node

  def visit_FunctionDef(self, node):
    """Creates the control variables and the single trailing return."""
    with self.state[_Function] as fn:
      with self.state[_Block] as block:
        block.is_function = True

        scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
        do_return_var_name = self.ctx.namer.new_symbol('do_return',
                                                       scope.referenced)
        retval_var_name = self.ctx.namer.new_symbol('retval_', scope.referenced)
        fn.do_return_var_name = do_return_var_name
        fn.retval_var_name = retval_var_name

        node.body = self._visit_statement_block(node, node.body)

        if block.return_used:

          if self.allow_missing_return:
            # The function would have a single `with` node that wraps the
            # entire body. If the function had a docstring, the body has two
            # nodes, with the `with` as the second node.
            wrapper_node = node.body[-1]
            assert isinstance(wrapper_node, gast.With), (
                'This transformer requires the functions converter.')

            template = """
              do_return_var_name = False
              retval_var_name = ag__.UndefinedReturnValue()
              body
              return function_context.ret(retval_var_name, do_return_var_name)
            """

            wrapper_node.body = templates.replace(
                template,
                body=wrapper_node.body,
                do_return_var_name=do_return_var_name,
                function_context=anno.getanno(node, 'function_context_name'),
                retval_var_name=retval_var_name)
          else:
            # NOTE(review): unlike the branch above, this template does not
            # initialize do_return_var_name or retval_var_name, yet loop tests
            # and guards added by this transformer read do_return_var_name.
            # Callers presumably use this mode only where that is safe.
            # Confirm.
            template = """
              body
              return retval_var_name
            """
            node.body = templates.replace(
                template,
                body=node.body,
                do_return_var_name=do_return_var_name,
                retval_var_name=retval_var_name)

    return node


def transform(node, ctx, default_to_null_return=True):
  """Ensure a function has only a single return, at the end.

  Args:
    node: the AST to transform.
    ctx: the converter context.
    default_to_null_return: forwarded to ReturnStatementsTransformer as
      `allow_missing_return`. When True, functions containing a return are
      expected to have their body wrapped in a `with` by the functions
      converter.

  Returns:
    The transformed AST.
  """
  node = qual_names.resolve(node)
  node = activity.resolve(node, ctx, None)

  # Note: Technically, these two could be merged into a single walk, but
  # keeping them separate helps with readability.
  node = ConditionalReturnRewriter(ctx).visit(node)

  # The rewrite above moved statements, so static analysis is refreshed
  # before the second pass reads the scope annotations.
  node = qual_names.resolve(node)
  node = activity.resolve(node, ctx, None)
  transformer = ReturnStatementsTransformer(
      ctx, allow_missing_return=default_to_null_return)
  node = transformer.visit(node)
  return node