• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Converts function definitions and lambdas by adding necessary boilerplate."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import gast
22
23from tensorflow.python.autograph.core import converter
24from tensorflow.python.autograph.pyct import anno
25from tensorflow.python.autograph.pyct import parser
26from tensorflow.python.autograph.pyct import qual_names
27from tensorflow.python.autograph.pyct import templates
28from tensorflow.python.autograph.pyct.static_analysis import activity
29from tensorflow.python.autograph.pyct.static_analysis import annos
30
31
32class _Function(object):
33
34  def __init__(self):
35    self.context_name = None
36
37
class FunctionTransformer(converter.Base):
  """Wraps function bodies around autograph-specific boilerplate.

  Top-level `def` bodies are moved inside a
  `with ag__.FunctionScope(...) as <context_name>:` block, and top-level
  lambda bodies are wrapped in an `ag__.with_function_scope(...)` call.
  Nested functions and lambdas are marked with `ag__.autograph_artifact`
  instead.
  """

  def _function_scope_options(self, fn_scope):
    """Returns the options with which to create function scopes.

    Args:
      fn_scope: the `_Function` state of the function being converted; its
        `level` attribute is 2 for the top-level function.

    Returns:
      The user-requested options for the top-level function, otherwise the
      result of `call_options()` on those options.
    """
    # Top-level function receive the options that were directly requested.
    # All others receive the options corresponding to a recursive conversion.
    # Note: this mainly controls the user_requested flag, which is important
    # primarily because the FunctionScope context also creates a
    # ControlStatusCtx(autograph=ENABLED) when user_requested is True. See
    # function_wrappers.py.
    if fn_scope.level == 2:
      return self.ctx.user.options
    return self.ctx.user.options.call_options()

  def visit_Lambda(self, node):
    """Converts a lambda, wrapping its body in a function scope.

    Nested lambdas (level > 2) are only wrapped in
    `ag__.autograph_artifact(...)` and returned as an expression.
    """
    with self.state[_Function] as fn_scope:
      node = self.generic_visit(node)

      # TODO(mdan): Fix the tests so that we can always add this decorator.
      if fn_scope.level > 2:
        return templates.replace_as_expression(
            'ag__.autograph_artifact(l)', l=node)

      scope = anno.getanno(node, anno.Static.SCOPE)
      function_context_name = self.ctx.namer.new_symbol('lscope',
                                                        scope.referenced)
      fn_scope.context_name = function_context_name
      # Recorded on the node so that later passes can find the context name.
      anno.setanno(node, 'function_context_name', function_context_name)

      template = """
        ag__.with_function_scope(
            lambda function_context: body, function_context_name, options)
      """
      node.body = templates.replace_as_expression(
          template,
          options=self._function_scope_options(fn_scope).to_ast(),
          function_context=function_context_name,
          function_context_name=gast.Constant(function_context_name, kind=None),
          body=node.body)

      return node

  @staticmethod
  def _split_docstring(body):
    """Separates a leading constant-expression statement from `body`.

    Args:
      body: list of statement nodes of a function.

    Returns:
      A `(docstring_node, remaining_body)` pair. `docstring_node` is None and
      `remaining_body` is `body` itself when the first statement is not an
      expression statement holding a constant.
    """
    if body:
      first_statement = body[0]
      if (isinstance(first_statement, gast.Expr) and
          isinstance(first_statement.value, gast.Constant)):
        return first_statement, body[1:]
    return None, body

  def visit_FunctionDef(self, node):
    """Converts a function definition, wrapping its body in a FunctionScope.

    Any leading docstring stays as the first statement of the function,
    outside the generated `with` block.
    """
    with self.state[_Function] as fn_scope:
      scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)

      function_context_name = self.ctx.namer.new_symbol('fscope',
                                                        scope.referenced)
      fn_scope.context_name = function_context_name
      # Recorded on the node so that later passes can find the context name.
      anno.setanno(node, 'function_context_name', function_context_name)

      node = self.generic_visit(node)

      if fn_scope.level <= 2:
        # Top-level functions lose their decorator because the conversion is
        # always just-in-time and by the time it happens the decorators are
        # already set to be applied.
        node.decorator_list = []
      else:
        # TODO(mdan): Fix the tests so that we can always add this decorator.
        # Inner functions are converted already, so we insert a decorator to
        # prevent double conversion. Double conversion would work too, but this
        # saves the overhead.
        node.decorator_list.append(
            parser.parse_expression('ag__.autograph_artifact'))

      docstring_node, body = self._split_docstring(node.body)
      if not body:
        # A body consisting only of a docstring (or `...`) would otherwise
        # produce a `with` statement with an empty body, which is not valid
        # Python.
        body = [gast.Pass()]

      template = """
        with ag__.FunctionScope(
            function_name, context_name, options) as function_context:
          body
      """
      wrapped_body = templates.replace(
          template,
          function_name=gast.Constant(node.name, kind=None),
          context_name=gast.Constant(function_context_name, kind=None),
          options=self._function_scope_options(fn_scope).to_ast(),
          function_context=function_context_name,
          body=body)

      if docstring_node is not None:
        wrapped_body = [docstring_node] + wrapped_body

      node.body = wrapped_body

      return node
133
def transform(node, ctx):
  """Runs FunctionTransformer over `node` after its prerequisite analyses.

  Qualified names are resolved first, then activity analysis runs; these
  presumably attach the scope annotations FunctionTransformer reads.

  Args:
    node: the AST to convert.
    ctx: the conversion context, passed to activity analysis and to the
      transformer.

  Returns:
    The transformed AST.
  """
  annotated = activity.resolve(qual_names.resolve(node), ctx, None)
  return FunctionTransformer(ctx).visit(annotated)
139