• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Lowers break statements to conditionals."""
16
17from tensorflow.python.autograph.core import converter
18from tensorflow.python.autograph.pyct import anno
19from tensorflow.python.autograph.pyct import qual_names
20from tensorflow.python.autograph.pyct import templates
21from tensorflow.python.autograph.pyct.static_analysis import activity
22from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
23
24
25class _Break(object):
26
27  def __init__(self):
28    self.used = False
29    self.control_var_name = None
30
31  def __repr__(self):
32    return 'used: %s, var: %s' % (self.used, self.control_var_name)
33
34
class BreakTransformer(converter.Base):
  """Rewrites `break` statements into a boolean flag plus loop-guard checks.

  Each loop whose body contains a `break` gets a freshly named flag variable.
  The `break` itself becomes `<flag> = True; continue`. A `while` loop folds
  `not <flag>` into its condition, while a `for` loop carries it as an
  EXTRA_LOOP_TEST annotation. The loop's former else clause is emitted as
  ordinary statements after the loop, guarded by `if not <flag>:` when a
  break was present.
  """

  def visit_Break(self, node):
    """Replaces a `break` with an assignment to the loop flag + `continue`."""
    loop_state = self.state[_Break]
    loop_state.used = True
    flag_name = loop_state.control_var_name
    # TODO(mdan): This will fail when expanded inside a top-level else block.
    snippet = """
      var_name = True
      continue
    """
    return templates.replace(snippet, var_name=flag_name)

  def _guard_if_present(self, block, var_name):
    """Wraps `block` in `if not <var_name>:`; empty blocks are returned as-is."""
    if block:
      guard = """
        if not var_name:
          block
      """
      return templates.replace(guard, var_name=var_name, block=block)
    return block

  def _process_body(self, nodes, break_var):
    """Visits a loop body inside a new _Break scope bound to `break_var`.

    Returns:
      A `(visited_nodes, break_used)` pair, where `break_used` tells whether
      the body contained a `break` belonging to this loop.
    """
    loop_state = self.state[_Break]
    loop_state.enter()
    loop_state.control_var_name = break_var
    visited = self.visit_block(nodes)
    saw_break = loop_state.used
    loop_state.exit()
    return visited, saw_break

  def _new_break_var(self, loop_node):
    """Picks a fresh `break_` flag name not referenced in the loop body."""
    body_scope = anno.getanno(loop_node, NodeAnno.BODY_SCOPE)
    return self.ctx.namer.new_symbol('break_', body_scope.referenced)

  def visit_While(self, node):
    """Lowers the `break` statements that target this `while` loop.

    Returns:
      A list of statements replacing `node`. It contains the rebuilt loop
      followed by its former else-clause statements. A flag initialization
      is prepended when the body contained a break.
    """
    flag = self._new_break_var(node)

    node.test = self.visit(node.test)
    node.body, has_break = self._process_body(node.body, flag)
    # Visited outside this loop's _Break scope on purpose: a break inside the
    # else clause targets an enclosing loop, not this one.
    node.orelse = self.visit_block(node.orelse)

    if not has_break:
      # Without a break, a loop that terminates normally always reaches its
      # else clause, so emitting it right after the loop is equivalent.
      template = """
        while test:
          body
        orelse
      """
      replacement = templates.replace(
          template, test=node.test, body=node.body, orelse=node.orelse)
      anno.copyanno(node, replacement[0], anno.Basic.DIRECTIVES)
      return replacement

    # The else clause may only run when the loop ended without a break,
    # i.e. when the flag is still False.
    orelse_if_no_break = self._guard_if_present(node.orelse, flag)

    template = """
      var_name = False
      while not var_name and test:
        body
      orelse
    """
    replacement = templates.replace(
        template,
        var_name=flag,
        test=node.test,
        body=node.body,
        orelse=orelse_if_no_break)
    # replacement[0] is the flag initialization; the loop comes second.
    anno.copyanno(node, replacement[1], anno.Basic.DIRECTIVES)
    return replacement

  def visit_For(self, node):
    """Lowers the `break` statements that target this `for` loop.

    A `for` header has no condition to extend, so the flag check is attached
    to the rewritten loop as an EXTRA_LOOP_TEST annotation instead.

    Returns:
      A list of statements replacing `node`, shaped like visit_While's result.
    """
    flag = self._new_break_var(node)

    node.target = self.visit(node.target)
    node.iter = self.visit(node.iter)
    node.body, has_break = self._process_body(node.body, flag)
    # Visited outside this loop's _Break scope on purpose: a break inside the
    # else clause targets an enclosing loop, not this one.
    node.orelse = self.visit_block(node.orelse)

    if not has_break:
      template = """
        for target in iter_:
          body
        orelse
      """
      replacement = templates.replace(
          template,
          iter_=node.iter,
          target=node.target,
          body=node.body,
          orelse=node.orelse)
      rebuilt_loop = replacement[0]
      for key in (anno.Basic.EXTRA_LOOP_TEST, anno.Basic.DIRECTIVES):
        anno.copyanno(node, rebuilt_loop, key)
      return replacement

    # The else clause may only run when the loop ended without a break,
    # i.e. when the flag is still False.
    orelse_if_no_break = self._guard_if_present(node.orelse, flag)
    extra_test = templates.replace_as_expression(
        'not var_name', var_name=flag)

    # Static analysis cannot see the flag read that lives in the
    # EXTRA_LOOP_TEST annotation, so the no-op `(var_name,)` statement keeps
    # the flag visibly used inside the loop body.
    # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
    template = """
      var_name = False
      for target in iter_:
        (var_name,)
        body
      orelse
    """
    replacement = templates.replace(
        template,
        var_name=flag,
        iter_=node.iter,
        target=node.target,
        body=node.body,
        orelse=orelse_if_no_break)

    # replacement[0] is the flag initialization; the loop comes second.
    rewritten_loop = replacement[1]
    anno.setanno(rewritten_loop, anno.Basic.EXTRA_LOOP_TEST, extra_test)
    anno.copyanno(node, rewritten_loop, anno.Basic.DIRECTIVES)
    return replacement
177
178
def transform(node, ctx):
  """Lowers `break` statements in `node` to flag variables and conditionals.

  Qualified-name and activity analyses run first, because BreakTransformer
  reads the BODY_SCOPE annotations that the activity analysis produces.

  Args:
    node: the AST to convert.
    ctx: the converter context shared by the analyses and the transformer.

  Returns:
    The transformed AST.
  """
  analyzed = activity.resolve(qual_names.resolve(node), ctx, None)
  return BreakTransformer(ctx).visit(analyzed)
186