• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Canonicalizes functions with multiple returns to use just one."""
16
17import gast
18
19from tensorflow.python.autograph.core import converter
20from tensorflow.python.autograph.pyct import anno
21from tensorflow.python.autograph.pyct import parser
22from tensorflow.python.autograph.pyct import qual_names
23from tensorflow.python.autograph.pyct import templates
24from tensorflow.python.autograph.pyct.static_analysis import activity
25from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
26
27
# Annotation keys attached to AST nodes by ConditionalReturnRewriter.

# Set on an `if` statement whose body always returns.
BODY_DEFINITELY_RETURNS = 'BODY_DEFINITELY_RETURNS'
# Set on an `if` statement whose `else` branch always returns.
ORELSE_DEFINITELY_RETURNS = 'ORELSE_DEFINITELY_RETURNS'
# Set on a compound statement (e.g. a `with`) whose body always returns.
STMT_DEFINITELY_RETURNS = 'STMT_DEFINITELY_RETURNS'
31
32
33class _RewriteBlock(object):
34
35  def __init__(self):
36    self.definitely_returns = False
37
38
class ConditionalReturnRewriter(converter.Base):
  """Moves trailing statements into the `if` branch that can fall through.

  When one branch of an `if` statement always returns, the statements that
  follow the `if` can only run after the other branch. Moving them into that
  branch avoids intermediate None return values.

  For example:

      if cond:
        <block 1>
        return
      else:
        <block 2>
      <block 3>

  becomes:

      if cond:
        <block 1>
        return
      else:
        <block 2>
        <block 3>

  The mirrored case works the same way. If the `else` branch always returns,
  the trailing statements are moved under the `if` branch instead.
  """

  def visit_Return(self, node):
    # A `return` reached directly in the current block ends that block.
    self.state[_RewriteBlock].definitely_returns = True
    return node

  def _postprocess_statement(self, node):
    """Hook run by `visit_block` after each statement of a block.

    Returns:
      A `(node, destination)` pair. `destination` is the branch list that
      the remaining statements of the block are moved into, per the rewrite
      described in the class docstring. It is None to leave them in place.
    """
    # A compound statement already known to always return (e.g. a `with`
    # whose body returns) makes the enclosing block always return as well.
    if anno.getanno(node, STMT_DEFINITELY_RETURNS, default=False):
      self.state[_RewriteBlock].definitely_returns = True

    # Collapse the typical conditional-return pattern into one `if` whose
    # branches may both return. This reduces None return values, which
    # don't work with TF conditionals.
    if not isinstance(node, gast.If):
      return node, None
    if anno.getanno(node, BODY_DEFINITELY_RETURNS, default=False):
      # Only the `else` branch can fall through; continue the block there.
      return node, node.orelse
    if anno.getanno(node, ORELSE_DEFINITELY_RETURNS, default=False):
      # Only the `if` branch can fall through; continue the block there.
      return node, node.body
    return node, None

  def _visit_statement_block(self, node, nodes):
    """Visits `nodes` in a fresh block.

    Returns:
      A `(new_nodes, always_returns)` pair.
    """
    rewrite_state = self.state[_RewriteBlock]
    rewrite_state.enter()
    visited = self.visit_block(nodes, after_visit=self._postprocess_statement)
    always_returns = rewrite_state.definitely_returns
    rewrite_state.exit()
    return visited, always_returns

  def visit_While(self, node):
    node.test = self.visit(node.test)
    # Whether a loop always returns is not tracked. The flags reported for
    # its body and `else` clause are discarded.
    node.body = self._visit_statement_block(node, node.body)[0]
    node.orelse = self._visit_statement_block(node, node.orelse)[0]
    return node

  def visit_For(self, node):
    node.iter = self.visit(node.iter)
    node.target = self.visit(node.target)
    # As for `while`: the "always returns" flags of loop blocks are ignored.
    node.body = self._visit_statement_block(node, node.body)[0]
    node.orelse = self._visit_statement_block(node, node.orelse)[0]
    return node

  def visit_With(self, node):
    node.items = self.visit_block(node.items)
    new_body, body_returns = self._visit_statement_block(node, node.body)
    node.body = new_body
    if body_returns:
      # Read back by `_postprocess_statement` of the enclosing block.
      anno.setanno(node, STMT_DEFINITELY_RETURNS, True)
    return node

  def visit_Try(self, node):
    # A `try` could be classified as always returning based on its parts.
    # That is deliberately not done, since a `try` is likely to raise in
    # some circumstances.
    node.body = self._visit_statement_block(node, node.body)[0]
    node.orelse = self._visit_statement_block(node, node.orelse)[0]
    node.finalbody = self._visit_statement_block(node, node.finalbody)[0]
    node.handlers = self.visit_block(node.handlers)
    return node

  def visit_ExceptHandler(self, node):
    # Revisit this if `try` statements ever get "always returns" tracking.
    node.body = self._visit_statement_block(node, node.body)[0]
    return node

  def visit_If(self, node):
    node.test = self.visit(node.test)

    new_body, body_returns = self._visit_statement_block(node, node.body)
    node.body = new_body
    if body_returns:
      anno.setanno(node, BODY_DEFINITELY_RETURNS, True)

    new_orelse, orelse_returns = self._visit_statement_block(
        node, node.orelse)
    node.orelse = new_orelse
    if orelse_returns:
      anno.setanno(node, ORELSE_DEFINITELY_RETURNS, True)

    # The `if` as a whole always returns only when both branches do.
    if body_returns and orelse_returns:
      self.state[_RewriteBlock].definitely_returns = True

    return node

  def visit_FunctionDef(self, node):
    node.args = self.visit(node.args)
    # The body is visited in a block of its own, so returns inside a nested
    # function do not mark the enclosing block. The flag is not needed here.
    node.body = self._visit_statement_block(node, node.body)[0]
    return node
153
154
155class _Block(object):
156
157  def __init__(self):
158    self.is_function = False
159    self.return_used = False
160    self.create_guard_next = False
161    self.create_guard_now = False
162
163  def __repr__(self):
164    return 'used: {}'.format(
165        self.return_used)
166
167
168class _Function(object):
169
170  def __init__(self):
171    self.do_return_var_name = None
172    self.retval_var_name = None
173
174  def __repr__(self):
175    return 'return control: {}, return value: {}'.format(
176        self.do_return_var_name, self.retval_var_name)
177
178
class ReturnStatementsTransformer(converter.Base):
  """Lowers return statements into variables and conditionals.

  Specifically, the following pattern:

      <block 1>
      return val
      <block 2>

  is converted to:

      do_return = False
      retval = None

      <block 1>

      do_return = True
      retval = val

      if not do_return:
        <block 2>

      return retval

  The conversion adjusts loops as well:

      <block 1>
      while cond:
        <block 2>
        return val

  is converted to:

      <block 1>
      while not do_return and cond:
        <block 2>
        do_return = True
        retval = val

  The initialization shown above is simplified. With `allow_missing_return`
  the return value starts out as `ag__.UndefinedReturnValue()` and the final
  return goes through the function context's `ret`. Without it, neither
  variable is initialized up front. The assignments that replace each
  `return` are also wrapped in a try/except (see `visit_Return`).
  """

  def __init__(self, ctx, allow_missing_return):
    super(ReturnStatementsTransformer, self).__init__(ctx)
    # Selects which epilogue `visit_FunctionDef` emits for functions that
    # contain a return.
    self.allow_missing_return = allow_missing_return

  def visit_Return(self, node):
    """Replaces `return <expr>` with assignments to the control variables."""
    # The return affects every block from the innermost one up to the
    # enclosing function. Each records the use and asks for the statements
    # that follow it to be guarded.
    for block in reversed(self.state[_Block].stack):
      block.return_used = True
      block.create_guard_next = True
      if block.is_function:
        break

    # A bare `return` yields None.
    if node.value is not None:
      retval = node.value
    else:
      retval = parser.parse_expression('None')

    # Note: if `return <expr>` raises, then the return is aborted.
    # The try/except below ensures the variables remain consistent in that
    # case.
    template = """
      try:
        do_return_var_name = True
        retval_var_name = retval
      except:
        do_return_var_name = False
        raise
    """
    return templates.replace(
        template,
        do_return_var_name=self.state[_Function].do_return_var_name,
        retval_var_name=self.state[_Function].retval_var_name,
        retval=retval)

  def _postprocess_statement(self, node):
    """Hook run by `visit_block` after each statement of a block.

    Returns:
      A `(node, block)` pair. When a guard is due, `node` is a new
      `if not do_return:` statement wrapping the current statement, and
      `block` is that guard's body, which also receives the statements that
      follow. Otherwise `node` is unchanged and `block` is None.
    """
    state = self.state[_Block]
    if not state.return_used:
      return node, None

    block = None
    if state.create_guard_now:
      template = """
        if not do_return_var_name:
          original_node
      """
      cond, = templates.replace(
          template,
          do_return_var_name=self.state[_Function].do_return_var_name,
          original_node=node)
      node, block = cond, cond.body

    # A guard requested while processing this statement applies to the
    # statement after it.
    state.create_guard_now = state.create_guard_next
    state.create_guard_next = False

    return node, block

  def _visit_statement_block(self, node, nodes):
    # Each statement list is processed with its own _Block state.
    self.state[_Block].enter()
    nodes = self.visit_block(nodes, after_visit=self._postprocess_statement)
    self.state[_Block].exit()
    return nodes

  def visit_While(self, node):
    node.test = self.visit(node.test)

    node.body = self._visit_statement_block(node, node.body)
    # Add the check for return to the loop condition. `return_used` is the
    # enclosing block's flag, so this also fires when that block returns
    # elsewhere; the extra check is conservative.
    if self.state[_Block].return_used:
      node.test = templates.replace_as_expression(
          'not control_var and test',
          test=node.test,
          control_var=self.state[_Function].do_return_var_name)

    # NOTE(review): statements in `orelse` are not guarded by the return
    # flag. Once the modified test ends the loop, a `while ... else` clause
    # would still run. This presumably relies on an earlier pass having moved
    # loop `else` clauses out of the loop -- confirm.
    node.orelse = self._visit_statement_block(node, node.orelse)
    return node

  def visit_For(self, node):
    node.iter = self.visit(node.iter)
    node.target = self.visit(node.target)

    node.body = self._visit_statement_block(node, node.body)
    # A `for` has no test to extend. The return check is added to the
    # EXTRA_LOOP_TEST annotation instead, combined with any existing one.
    if self.state[_Block].return_used:
      control_var = self.state[_Function].do_return_var_name
      extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST, default=None)
      if extra_test is not None:
        extra_test = templates.replace_as_expression(
            'not control_var and extra_test',
            extra_test=extra_test,
            control_var=control_var)
      else:
        extra_test = templates.replace_as_expression(
            'not control_var',
            control_var=control_var)
      anno.setanno(node, anno.Basic.EXTRA_LOOP_TEST, extra_test)

    node.orelse = self._visit_statement_block(node, node.orelse)
    return node

  def visit_With(self, node):
    node.items = self.visit_block(node.items)
    node.body = self._visit_statement_block(node, node.body)
    return node

  def visit_Try(self, node):
    node.body = self._visit_statement_block(node, node.body)
    node.orelse = self._visit_statement_block(node, node.orelse)
    node.finalbody = self._visit_statement_block(node, node.finalbody)
    node.handlers = self.visit_block(node.handlers)
    return node

  def visit_ExceptHandler(self, node):
    node.body = self._visit_statement_block(node, node.body)
    return node

  def visit_If(self, node):
    node.test = self.visit(node.test)
    node.body = self._visit_statement_block(node, node.body)
    node.orelse = self._visit_statement_block(node, node.orelse)
    return node

  def visit_FunctionDef(self, node):
    """Allocates the control variables and appends the single final return.

    Raises:
      AssertionError: if `allow_missing_return` is set and the function body
        does not end in the `with` node expected from the functions converter.
    """
    with self.state[_Function] as fn:
      with self.state[_Block] as block:
        block.is_function = True

        scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
        do_return_var_name = self.ctx.namer.new_symbol('do_return',
                                                       scope.referenced)
        retval_var_name = self.ctx.namer.new_symbol('retval_', scope.referenced)
        fn.do_return_var_name = do_return_var_name
        fn.retval_var_name = retval_var_name

        node.body = self._visit_statement_block(node, node.body)

        if block.return_used:

          if self.allow_missing_return:
            # The function would have a single `with` node that wraps the
            # entire body. If the function had a docstring, the body has two
            # nodes, with the `with` as the second node.
            wrapper_node = node.body[-1]
            # Checked explicitly rather than with `assert`, so the check is
            # not stripped when Python runs with -O.
            if not isinstance(wrapper_node, gast.With):
              raise AssertionError(
                  'This transformer requires the functions converter.')

            template = """
              do_return_var_name = False
              retval_var_name = ag__.UndefinedReturnValue()
              body
              return function_context.ret(retval_var_name, do_return_var_name)
            """

            wrapper_node.body = templates.replace(
                template,
                body=wrapper_node.body,
                do_return_var_name=do_return_var_name,
                function_context=anno.getanno(node, 'function_context_name'),
                retval_var_name=retval_var_name)
          else:
            # No up-front initialization in this mode. The return variable is
            # only assigned by the lowered `return` statements.
            template = """
              body
              return retval_var_name
            """
            node.body = templates.replace(
                template,
                body=node.body,
                retval_var_name=retval_var_name)

    return node
386
387
def transform(node, ctx, default_to_null_return=True):
  """Ensure a function has only a single return, at the end.

  Args:
    node: the AST to convert.
    ctx: the conversion context, passed to the analyses and to both passes.
    default_to_null_return: forwarded to `ReturnStatementsTransformer` as
      `allow_missing_return`.

  Returns:
    The converted AST.
  """

  def _reanalyze(tree):
    # Re-run the analyses so the annotations reflect the current tree.
    tree = qual_names.resolve(tree)
    return activity.resolve(tree, ctx, None)

  # Note: the two passes could be merged into a single walk; keeping them
  # separate helps with readability.
  analyzed = _reanalyze(node)
  node = ConditionalReturnRewriter(ctx).visit(analyzed)

  analyzed = _reanalyze(node)
  lowering = ReturnStatementsTransformer(
      ctx, allow_missing_return=default_to_null_return)
  return lowering.visit(analyzed)
403