• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Canonicalizes `continue` statements by desugaring them into a control boolean."""
16
17from tensorflow.python.autograph.core import converter
18from tensorflow.python.autograph.pyct import anno
19from tensorflow.python.autograph.pyct import qual_names
20from tensorflow.python.autograph.pyct import templates
21from tensorflow.python.autograph.pyct.static_analysis import activity
22from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
23
24
25class _Continue(object):
26
27  def __init__(self):
28    self.used = False
29    self.control_var_name = None
30
31  def __repr__(self):
32    return '<_Continue(used: {}, var: {})>'.format(self.used,
33                                                   self.control_var_name)
34
35
36class _Block(object):
37  """Tracks information about lexical blocks as they are visited in the AST.
38
39  Mainly, this object tracks the creation of block guards that replace
40  `continue` statements (e.g. `if not continue_:`).
41
42  Attributes:
43    create_guard_current: bool, whether to create a guard for the current
44      statement.
45    create_guard_next: bool, whether to create a guard for the next
46      statement.
47    is_loop_type: bool, whether this block is the body of a loop.
48  """
49
50  def __init__(self):
51    self.is_loop_type = False
52    self.create_guard_current = False
53    self.create_guard_next = False
54
55
class ContinueCanonicalizationTransformer(converter.Base):
  """Canonicalizes continue statements into additional conditionals.

  Each `continue` inside a loop body is replaced by an assignment that sets a
  per-loop boolean control variable to True. The statements that follow it
  are wrapped in `if not <control var>:` guards, and the control variable is
  initialized to False at the top of the loop body. For example:

      for x in xs:
        if x > 0:
          continue
        f(x)

  becomes, roughly (the exact variable name comes from the namer):

      for x in xs:
        continue_ = False
        if x > 0:
          continue_ = True
        if not continue_:
          f(x)
  """

  def visit_Continue(self, node):
    """Replaces `continue` with an assignment to the loop's control variable.

    Also marks every enclosing block, up to and including the innermost loop
    body, so that the statement after the one containing this `continue` is
    guarded in each of those blocks.

    Args:
      node: the Continue node; it is discarded.

    Returns:
      The statements produced by the `var_name = True` template.
    """
    self.state[_Continue].used = True
    for block in reversed(self.state[_Block].stack):
      # See ContinueCanonicalizationTest.test_multiple_continues for an example
      # of why it's necessary to create guards for all enclosing affected
      # blocks, not just for the current block.
      block.create_guard_next = True
      if block.is_loop_type:
        # continue only affects the innermost loop
        break
    template = """
      var_name = True
    """
    return templates.replace(
        template, var_name=self.state[_Continue].control_var_name)

  def _postprocess_statement(self, node):
    """Guards `node` behind the control variable when its block requests it.

    Used as the `after_visit` callback of `visit_block`. Guarding works with a
    one-statement delay: a `continue` found while visiting statement i sets
    `create_guard_next`, which is promoted to `create_guard_current` here, so
    statement i + 1 is the first one that gets wrapped.

    Args:
      node: the already-visited statement, or the statement sequence that a
        visitor (e.g. `visit_Continue`) returned in its place.

    Returns:
      A pair `(replacement, new_body)`. When wrapping, `replacement` is the new
      `if not <control var>:` node and `new_body` is its body; otherwise the
      pair is `(node, None)`. The second element is handed back to
      `visit_block`, presumably so that later statements of the block are
      placed inside the guard -- confirm against `visit_block`.
    """
    if self.state[_Continue].used:
      block = self.state[_Block]
      should_wrap_current = block.create_guard_current
      # After processing propagate whether to guard the next statement
      block.create_guard_current = block.create_guard_next
      block.create_guard_next = False
      if should_wrap_current:
        template = """
          if not var_name:
            original_node
        """
        cond, = templates.replace(
            template,
            var_name=self.state[_Continue].control_var_name,
            original_node=node)
        return cond, cond.body
    return node, None

  def _visit_loop_body(self, node, nodes):
    """Visits the body of a loop with fresh continue and block state.

    Args:
      node: the loop node (While or For). Its NodeAnno.BODY_SCOPE annotation
        provides the names referenced in the body, which are passed to the
        namer when creating the control variable.
      nodes: the list of statements forming the loop body.

    Returns:
      The transformed statement list. If the body contained a `continue`, the
      list starts with `<control var> = False`, so the flag is reset at the
      start of every iteration.
    """
    self.state[_Continue].enter()
    self.state[_Block].enter()
    self.state[_Block].is_loop_type = True
    scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    continue_var = self.ctx.namer.new_symbol('continue_', scope.referenced)
    self.state[_Continue].control_var_name = continue_var

    nodes = self.visit_block(nodes, after_visit=self._postprocess_statement)

    if self.state[_Continue].used:
      template = """
        var_name = False
      """
      control_var_init = templates.replace(template, var_name=continue_var)
      nodes = control_var_init + nodes

    self.state[_Block].exit()
    self.state[_Continue].exit()
    return nodes

  def _visit_non_loop_body(self, nodes):
    """Visits a non-loop statement list (if/else, with, try, except bodies).

    The list gets its own _Block, so its statements can be guarded
    independently of the enclosing block. The _Continue state is shared with
    the enclosing loop.
    """
    self.state[_Block].enter()
    nodes = self.visit_block(nodes, after_visit=self._postprocess_statement)
    self.state[_Block].exit()
    return nodes

  def visit_While(self, node):
    node.test = self.visit(node.test)
    node.body = self._visit_loop_body(node, node.body)
    # A continue in the else clause applies to the containing scope.
    node.orelse = self._visit_non_loop_body(node.orelse)
    return node

  def visit_For(self, node):
    node.target = self.generic_visit(node.target)
    node.iter = self.generic_visit(node.iter)
    node.body = self._visit_loop_body(node, node.body)
    # A continue in the else clause applies to the containing scope.
    node.orelse = self._visit_non_loop_body(node.orelse)
    return node

  def visit_If(self, node):
    # node.test is left alone: a `continue` cannot appear in an expression.
    node.body = self._visit_non_loop_body(node.body)
    node.orelse = self._visit_non_loop_body(node.orelse)
    return node

  def visit_With(self, node):
    node.items = self.visit_block(node.items)
    node.body = self._visit_non_loop_body(node.body)
    return node

  def visit_Try(self, node):
    node.body = self._visit_non_loop_body(node.body)
    node.orelse = self._visit_non_loop_body(node.orelse)
    # In Python 3.8 and later continue is allowed in finally blocks
    node.finalbody = self._visit_non_loop_body(node.finalbody)
    # Each handler's body is guarded via visit_ExceptHandler.
    node.handlers = self.visit_block(node.handlers)
    return node

  def visit_ExceptHandler(self, node):
    node.body = self._visit_non_loop_body(node.body)
    return node
157
158
def transform(node, ctx):
  """Canonicalizes all `continue` statements in `node`.

  Args:
    node: the AST to transform.
    ctx: the converter context handed to ContinueCanonicalizationTransformer,
      which uses its namer to generate control variable names.

  Returns:
    The transformed AST.
  """
  node = qual_names.resolve(node)
  # The transformer reads NodeAnno.BODY_SCOPE annotations on loop nodes;
  # presumably activity analysis is what attaches them.
  node = activity.resolve(node, ctx, None)

  node = ContinueCanonicalizationTransformer(ctx).visit(node)
  return node
165