• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Handles function calls, by generating compiled function names and calls.
16
17Note: this transformer does not rename the top level object being converted;
18that is the caller's responsibility.
19
20Requires function_scopes.
21"""
22
23import gast
24
25from tensorflow.python.autograph.core import converter
26from tensorflow.python.autograph.pyct import anno
27from tensorflow.python.autograph.pyct import parser
28from tensorflow.python.autograph.pyct import qual_names
29from tensorflow.python.autograph.pyct import templates
30from tensorflow.python.autograph.utils import ag_logging
31
32
33# TODO(mdan): Rename to FunctionCallsTransformer.
34
35
36class _Function(object):
37
38  no_root = True
39
40  def __init__(self):
41    self.context_name = None
42
43
# Module-level latch: flipped to True by CallTreeTransformer.visit_Call after
# the set_trace warning has been logged, so the warning is emitted only once.
set_trace_warned = False
45
46
class _ArgTemplateBuilder(object):
  """Builds an AST expression evaluating to a call's positional arguments.

  Plain arguments are grouped into tuple literals, starred arguments are
  wrapped in a `tuple(...)` call, and the pieces are joined with `+`.

  Example (yes, it's legal Python 3):

      f(*args1, b, *args2, c, d)
          ->  tuple(args1) + (b,) + tuple(args2) + (c, d)

  Usage: call `add_arg` / `add_stararg` in argument order, then `finalize`,
  then `to_ast`.
  """

  def __init__(self):
    # Set by finalize(); to_ast() asserts on it.
    self._finalized = False
    # Completed pieces (tuple literals and tuple(...) calls), in call order.
    self._argspec = []
    # Plain arguments seen since the last starred argument.
    self._arg_accumulator = []

  def _consume_args(self):
    """Moves any pending plain arguments into a single tuple literal piece."""
    if not self._arg_accumulator:
      return
    pending = gast.Tuple(elts=self._arg_accumulator, ctx=gast.Load())
    self._argspec.append(pending)
    # Start a fresh list; the previous one is now owned by the Tuple node.
    self._arg_accumulator = []

  def add_arg(self, a):
    """Records a regular positional argument node."""
    self._arg_accumulator.append(a)

  def add_stararg(self, a):
    """Records the value of a starred argument (the `x` in `*x`)."""
    # Plain arguments preceding the star must come first in the result.
    self._consume_args()
    tuple_fn = gast.Name(
        'tuple', ctx=gast.Load(), annotation=None, type_comment=None)
    self._argspec.append(gast.Call(tuple_fn, args=[a], keywords=()))

  def finalize(self):
    """Flushes trailing plain arguments and marks the builder as complete."""
    self._consume_args()
    self._finalized = True

  def to_ast(self):
    """Returns the combined expression; requires a prior call to finalize.

    With no arguments recorded, the result is an empty tuple literal.
    """
    assert self._finalized
    if not self._argspec:
      return gast.Tuple([], gast.Load())
    pieces = self._argspec
    combined = pieces[0]
    # Left-fold the remaining pieces: ((p0 + p1) + p2) + ...
    for piece in pieces[1:]:
      combined = gast.BinOp(combined, gast.Add(), piece)
    return combined
90
91
class CallTreeTransformer(converter.Base):
  """Rewrites function calls as `ag__.converted_call(...)` expressions.

  Each function or lambda scope records the name of its function context
  (read from the 'function_context_name' annotation); that name is passed
  along in every call rewritten inside the scope.
  """

  def visit_Lambda(self, node):
    """Visits a lambda, opening a function scope when one is annotated."""
    if anno.hasanno(node, 'function_context_name'):
      with self.state[_Function] as fn_scope:
        fn_scope.context_name = anno.getanno(node, 'function_context_name')
        return self.generic_visit(node)
    # Lambda functions created during the conversion process have no context
    # manager; they are visited within whatever scope is currently active.
    return self.generic_visit(node)

  def visit_FunctionDef(self, node):
    """Visits a function definition, scoping its body and return annotation."""
    # Decorators and argument defaults are evaluated in the outer scope, so
    # they are processed before the function's own scope is opened.
    node.decorator_list = self.visit_block(node.decorator_list)
    node.args.defaults = self.visit_block(node.args.defaults)
    for index, default in enumerate(node.args.kw_defaults):
      if default is None:
        # Keyword-only argument without a default value.
        continue
      node.args.kw_defaults[index] = self.visit(default)

    with self.state[_Function] as fn_scope:
      # Note: if the conversion process ever creates helper functions, this
      # assumption will no longer hold.
      assert anno.hasanno(node, 'function_context_name'), (
          'The function_scopes converter always creates a scope for functions.')
      fn_scope.context_name = anno.getanno(node, 'function_context_name')
      node.body = self.visit_block(node.body)
      if node.returns:
        node.returns = self.visit(node.returns)
      return node

  def visit_With(self, node):
    """Visits only the body of a `with` statement."""
    # The context manager expressions in node.items are deliberately skipped,
    # so calls appearing there are not converted.
    node.body = self.visit_block(node.body)
    return node

  def _args_to_tuple(self, node):
    """Ties together all positional and *arg arguments in a single tuple."""
    # TODO(mdan): We could rewrite this to just a call to tuple(). Maybe better?
    # For example for
    #   f(a, b, *args)
    # instead of writing:
    #   (a, b) + args
    # just write this?
    #   tuple(a, b, *args)
    tuple_builder = _ArgTemplateBuilder()
    for arg in node.args:
      if not isinstance(arg, gast.Starred):
        tuple_builder.add_arg(arg)
      else:
        tuple_builder.add_stararg(arg.value)
    tuple_builder.finalize()
    return tuple_builder.to_ast()

  def _kwargs_to_dict(self, node):
    """Ties together all keyword and **kwarg arguments in a single dict.

    Returns an expression for `None` when the call has no keywords.
    """
    if not node.keywords:
      return parser.parse_expression('None')
    dict_fn = gast.Name(
        'dict', ctx=gast.Load(), annotation=None, type_comment=None)
    return gast.Call(dict_fn, args=(), keywords=node.keywords)

  def visit_Call(self, node):
    """Replaces a call with `ag__.converted_call`, except for exempt targets.

    Exempt targets are returned with only their children visited: calls into
    `ag__`, calls on the function context object, `pdb.set_trace`,
    `ipdb.set_trace`, `breakpoint`, and `print` when the BUILTIN_FUNCTIONS
    feature is not enabled in the user options.
    """
    # Both values are captured before the children are visited.
    callee_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
    ctx_name = self.state[_Function].context_name
    node = self.generic_visit(node)

    # TODO(mdan): Refactor converted_call as a 'Call' operator.

    # Calls to the internal 'ag__' module are never converted (though their
    # arguments might be).
    if callee_name.startswith('ag__.'):
      return node

    # Calls to the function context manager (inserted by function_scopes) are
    # also safe.
    if callee_name.startswith(ctx_name + '.'):
      return node

    # Calls to pdb.set_trace or ipdb.set_trace are never converted. We don't use
    # the normal mechanisms to bypass these literals because they are sensitive
    # to the frame they are being called from.
    # TODO(mdan): Generalize this to a "static allowlist" config.
    if callee_name in ('pdb.set_trace', 'ipdb.set_trace', 'breakpoint'):
      global set_trace_warned
      if not set_trace_warned:
        # TODO(mdan): Update and shorten once available on tensorflow.org.
        ag_logging.warning(
            'Detected `pdb.set_trace()` in user code. The code'
            ' generated by AutoGraph is not optimized for step-by-step'
            ' debugging. See https://github.com/tensorflow/tensorflow/'
            'blob/master/tensorflow/python/autograph/g3doc/reference/'
            'debugging.md.')
        # Warn only once per process.
        set_trace_warned = True
      return node

    if callee_name == 'print':
      if not self.ctx.user.options.uses(converter.Feature.BUILTIN_FUNCTIONS):
        return node

    template = """
      ag__.converted_call(func, args, kwargs, function_ctx)
    """
    return templates.replace_as_expression(
        template,
        func=node.func,
        args=self._args_to_tuple(node),
        kwargs=self._kwargs_to_dict(node),
        function_ctx=ctx_name)
205
206
def transform(node, ctx):
  """Transforms function calls into their compiled counterparts.

  The tree is first passed through `qual_names.resolve` and then rewritten by
  `CallTreeTransformer`.
  NOTE(review): the resolve step presumably attaches the qualified-name
  annotations that `CallTreeTransformer.visit_Call` reads - confirm.

  Args:
    node: AST
    ctx: EntityContext; handed to the CallTreeTransformer constructor.

  Returns:
    The transformed AST. (Only the node is returned; no set of newly
    generated names accompanies it.)
  """
  node = qual_names.resolve(node)

  node = CallTreeTransformer(ctx).visit(node)
  return node
222