# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Lowers break statements to conditionals."""

from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno


class _Break(object):
  """Break-lowering state for the innermost loop currently being processed.

  A fresh instance is pushed onto the transformer state each time a loop body
  is visited (see `BreakTransformer._process_body`).
  """

  def __init__(self):
    # Becomes True once a `break` belonging to this loop has been lowered.
    self.used = False
    # Name of the boolean control variable that stands in for `break`.
    self.control_var_name = None

  def __repr__(self):
    return 'used: %s, var: %s' % (self.used, self.control_var_name)


class BreakTransformer(converter.Base):
  """Canonicalizes break statements into additional conditionals.

  Each `break` is replaced by an assignment of True to a freshly named control
  variable, followed by `continue`. The enclosing loop then consults that
  variable: a `while` loop adds it to its condition, and a `for` loop receives
  it as an `EXTRA_LOOP_TEST` annotation. Loop `else` clauses are moved after
  the loop and guarded by the same variable.
  """

  def visit_Break(self, node):
    break_state = self.state[_Break]
    break_state.used = True
    # TODO(mdan): This will fail when expanded inside a top-level else block.
    template = """
      var_name = True
      continue
    """
    return templates.replace(template, var_name=break_state.control_var_name)

  def _guard_if_present(self, block, var_name):
    """Wraps `block` so that it only executes while `var_name` is falsy.

    Args:
      block: statements to guard (a loop's visited `orelse`); may be empty.
      var_name: name of the break control variable.

    Returns:
      The guarded statements, or `block` itself when it is empty.
    """
    if block:
      template = """
        if not var_name:
          block
      """
      return templates.replace(template, var_name=var_name, block=block)
    return block

  def _process_body(self, nodes, break_var):
    """Visits a loop body under a new `_Break` state bound to `break_var`.

    Args:
      nodes: the loop body statements.
      break_var: name of the control variable for this loop.

    Returns:
      A tuple `(visited_nodes, break_used)`; `break_used` reports whether any
      `break` targeting this loop was lowered while visiting the body.
    """
    break_state = self.state[_Break]
    break_state.enter()
    break_state.control_var_name = break_var
    visited = self.visit_block(nodes)
    break_used = break_state.used
    break_state.exit()
    return visited, break_used

  def visit_While(self, node):
    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    break_var = self.ctx.namer.new_symbol('break_', body_scope.referenced)

    node.test = self.visit(node.test)
    node.body, break_used = self._process_body(node.body, break_var)
    # A break in the else clause applies to the containing scope.
    node.orelse = self.visit_block(node.orelse)

    if break_used:
      # The loop's else clause must only run when the loop exited cleanly
      # (i.e. no break was taken), so it is guarded by the control variable.
      template = """
        var_name = False
        while not var_name and test:
          body
        orelse
      """
      new_nodes = templates.replace(
          template,
          var_name=break_var,
          test=node.test,
          body=node.body,
          orelse=self._guard_if_present(node.orelse, break_var))
      # The loop comes right after the control variable initialization.
      new_while_node = new_nodes[1]
    else:
      template = """
        while test:
          body
        orelse
      """
      new_nodes = templates.replace(
          template, test=node.test, body=node.body, orelse=node.orelse)
      new_while_node = new_nodes[0]

    anno.copyanno(node, new_while_node, anno.Basic.DIRECTIVES)
    return new_nodes

  def visit_For(self, node):
    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    break_var = self.ctx.namer.new_symbol('break_', body_scope.referenced)

    node.target = self.visit(node.target)
    node.iter = self.visit(node.iter)
    node.body, break_used = self._process_body(node.body, break_var)
    # A break in the else clause applies to the containing scope.
    node.orelse = self.visit_block(node.orelse)

    if break_used:
      # The loop's else clause must only run when the loop exited cleanly
      # (i.e. no break was taken), so it is guarded by the control variable.
      guarded_orelse = self._guard_if_present(node.orelse, break_var)
      extra_test = templates.replace_as_expression(
          'not var_name', var_name=break_var)

      # The extra test is hidden in the AST, which would confuse the static
      # analysis. To mitigate that, a no-op statement is inserted so that the
      # control variable is marked as used.
      # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
      template = """
        var_name = False
        for target in iter_:
          (var_name,)
          body
        orelse
      """
      new_nodes = templates.replace(
          template,
          var_name=break_var,
          iter_=node.iter,
          target=node.target,
          body=node.body,
          orelse=guarded_orelse)
      # The loop comes right after the control variable initialization.
      new_for_node = new_nodes[1]
      anno.setanno(new_for_node, anno.Basic.EXTRA_LOOP_TEST, extra_test)
    else:
      template = """
        for target in iter_:
          body
        orelse
      """
      new_nodes = templates.replace(
          template,
          iter_=node.iter,
          target=node.target,
          body=node.body,
          orelse=node.orelse)
      new_for_node = new_nodes[0]
      anno.copyanno(node, new_for_node, anno.Basic.EXTRA_LOOP_TEST)

    anno.copyanno(node, new_for_node, anno.Basic.DIRECTIVES)
    return new_nodes


def transform(node, ctx):
  """Lowers all break statements found in `node`.

  Qualified-name and activity analyses are run first, because the loop
  visitors read the `NodeAnno.BODY_SCOPE` annotation.

  Args:
    node: the AST to transform.
    ctx: the converter context passed to `activity.resolve` and the transformer.

  Returns:
    The transformed AST.
  """
  annotated = activity.resolve(qual_names.resolve(node), ctx, None)
  return BreakTransformer(ctx).visit(annotated)