# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Canonicalizes continue statements by de-sugaring into a control boolean."""

from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno


class _Continue(object):
  """Per-loop record of whether, and through which variable, `continue` is used.

  Attributes:
    used: bool, set to True by `visit_Continue` once a `continue` statement
      has been rewritten while this record is the current one.
    control_var_name: the symbol returned by `ctx.namer.new_symbol` in
      `_visit_loop_body`, used as the loop's boolean control variable; None
      until it is assigned.
  """

  def __init__(self):
    self.used = False
    self.control_var_name = None

  def __repr__(self):
    return '<_Continue(used: {}, var: {})>'.format(self.used,
                                                   self.control_var_name)


class _Block(object):
  """Tracks information about lexical blocks as they are visited in the AST.

  Mainly, this object tracks the creation of block guards that replace
  `continue` statements (e.g. `if not continue_:`).

  After each statement of a block is visited, `create_guard_next` is shifted
  into `create_guard_current` and then cleared (see
  `ContinueCanonicalizationTransformer._postprocess_statement`). A request made
  while visiting one statement therefore takes effect on the statement that
  follows it.

  Attributes:
    create_guard_current: bool, whether to create a guard for the current
      statement.
    create_guard_next: bool, whether to create a guard for the next
      statement.
    is_loop_type: bool, whether this block is the body of a loop.
  """

  def __init__(self):
    self.is_loop_type = False
    self.create_guard_current = False
    self.create_guard_next = False


class ContinueCanonicalizationTransformer(converter.Base):
  """Canonicalizes continue statements into additional conditionals.

  Each `continue` is replaced by an assignment `<control var> = True`. The
  statements that would have been skipped are wrapped in
  `if not <control var>:` guards. The loop body is prefixed with
  `<control var> = False`, which resets the flag on every iteration.
  """

  def visit_Continue(self, node):
    """Replaces a `continue` statement with `<control var> = True`.

    Also flags the enclosing blocks so that, in each of them, the statement
    following the one that contains this `continue` gets guarded. Flagging
    starts at the innermost block and stops after the innermost loop body.
    """
    self.state[_Continue].used = True
    for block in reversed(self.state[_Block].stack):
      # See ContinueCanonicalizationTest.test_multiple_continues for an example
      # of why it's necessary to create guards for all enclosing affected
      # blocks, not just that of the current block.
      block.create_guard_next = True
      if block.is_loop_type:
        # continue only affects the innermost loop
        break
    template = """
      var_name = True
    """
    return templates.replace(
        template, var_name=self.state[_Continue].control_var_name)

  def _postprocess_statement(self, node):
    """Wraps `node` in an `if not <control var>:` guard when required.

    Used as the `after_visit` callback of `visit_block`. The current block's
    guard flags are only advanced once a `continue` has been recorded in the
    current `_Continue` state.

    Args:
      node: the statement that was just visited.

    Returns:
      A tuple `(replacement, destination)`. When a guard is created it is
      `(cond, cond.body)`; otherwise it is `(node, None)`.
      NOTE(review): the second element is interpreted by `visit_block`, which
      is not shown here. It presumably redirects the block's remaining
      statements into `cond.body`, so one guard covers the rest of the block.
      Confirm against converter.Base.visit_block.
    """
    if self.state[_Continue].used:
      block = self.state[_Block]
      # A guard requested while visiting an earlier statement applies here.
      should_wrap_current = block.create_guard_current
      # After processing propagate whether to guard the next statement
      block.create_guard_current = block.create_guard_next
      block.create_guard_next = False
      if should_wrap_current:
        template = """
          if not var_name:
            original_node
        """
        cond, = templates.replace(
            template,
            var_name=self.state[_Continue].control_var_name,
            original_node=node)
        return cond, cond.body
    return node, None

  def _visit_loop_body(self, node, nodes):
    """Visits the body of a loop, rewriting the `continue`s that target it.

    Args:
      node: the loop node (While or For). Its NodeAnno.BODY_SCOPE annotation
        supplies `scope.referenced`, which is passed to the namer when
        creating the control variable. Presumably this keeps the new name
        from clashing with names used in the body.
      nodes: the list of statements forming the loop body.

    Returns:
      The transformed list of statements. If a `continue` was rewritten, the
      list is prefixed with `<control var> = False`, which resets the flag at
      the start of every iteration.
    """
    # Fresh per-loop state. enter()/exit() come from the converter state
    # machinery, which is not shown here.
    self.state[_Continue].enter()
    self.state[_Block].enter()
    # Marks this block as the stopping point for visit_Continue's propagation.
    self.state[_Block].is_loop_type = True
    scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    continue_var = self.ctx.namer.new_symbol('continue_', scope.referenced)
    self.state[_Continue].control_var_name = continue_var

    nodes = self.visit_block(nodes, after_visit=self._postprocess_statement)

    # Only emit the control variable when some continue actually used it.
    if self.state[_Continue].used:
      template = """
        var_name = False
      """
      control_var_init = templates.replace(template, var_name=continue_var)
      nodes = control_var_init + nodes

    self.state[_Block].exit()
    self.state[_Continue].exit()
    return nodes

  def _visit_non_loop_body(self, nodes):
    """Visits a statement list that is not a loop body.

    Used for if/else branches, `with` bodies, try/except/finally bodies and
    loop `else` clauses. The new block scope is not marked as a loop, so
    `visit_Continue` keeps flagging blocks outward past it.

    Args:
      nodes: the list of statements to visit.

    Returns:
      The transformed list of statements.
    """
    self.state[_Block].enter()
    nodes = self.visit_block(nodes, after_visit=self._postprocess_statement)
    self.state[_Block].exit()
    return nodes

  def visit_While(self, node):
    node.test = self.visit(node.test)
    node.body = self._visit_loop_body(node, node.body)
    # A continue in the else clause applies to the containing scope.
    node.orelse = self._visit_non_loop_body(node.orelse)
    return node

  def visit_For(self, node):
    node.target = self.generic_visit(node.target)
    node.iter = self.generic_visit(node.iter)
    node.body = self._visit_loop_body(node, node.body)
    # A continue in the else clause applies to the containing scope.
    node.orelse = self._visit_non_loop_body(node.orelse)
    return node

  def visit_If(self, node):
    # Only the branches are visited here; node.test is left untouched.
    node.body = self._visit_non_loop_body(node.body)
    node.orelse = self._visit_non_loop_body(node.orelse)
    return node

  def visit_With(self, node):
    node.items = self.visit_block(node.items)
    node.body = self._visit_non_loop_body(node.body)
    return node

  def visit_Try(self, node):
    node.body = self._visit_non_loop_body(node.body)
    node.orelse = self._visit_non_loop_body(node.orelse)
    # In Python 3.8 and later continue is allowed in finally blocks
    node.finalbody = self._visit_non_loop_body(node.finalbody)
    # Handler bodies are presumably processed by visit_ExceptHandler below,
    # assuming visit_block dispatches each handler through self.visit.
    node.handlers = self.visit_block(node.handlers)
    return node

  def visit_ExceptHandler(self, node):
    node.body = self._visit_non_loop_body(node.body)
    return node


def transform(node, ctx):
  """Rewrites the `continue` statements in `node` into guarded conditionals.

  Args:
    node: the AST to transform.
    ctx: the converter context. It is passed to activity analysis and to the
      transformer, which uses `ctx.namer` to create control variables.

  Returns:
    The transformed AST.
  """
  node = qual_names.resolve(node)
  # Activity analysis presumably attaches the NodeAnno.BODY_SCOPE annotations
  # that _visit_loop_body reads via anno.getanno.
  node = activity.resolve(node, ctx, None)

  node = ContinueCanonicalizationTransformer(ctx).visit(node)
  return node