• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Handles function calls, by generating compiled function names and calls.
16
17Note: this transformer does not rename the top level object being converted;
18that is the caller's responsibility.
19"""
20
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25import gast
26
27from tensorflow.python.autograph.core import converter
28from tensorflow.python.autograph.pyct import anno
29from tensorflow.python.autograph.pyct import ast_util
30from tensorflow.python.autograph.pyct import parser
31from tensorflow.python.autograph.pyct import templates
32
33
34# TODO(mdan): Rename to FunctionCallsTransformer.
35
36
37class _Function(object):
38
39  no_root = True
40
41
class CallTreeTransformer(converter.Base):
  """Transforms the call tree by renaming transformed symbols.

  Every call `f(...)` that is not exempt is rewritten into
  `ag__.converted_call(func, owner, options, args, kwargs)`, where `args` is
  an expression evaluating to a tuple of the positional arguments and
  `kwargs` an expression evaluating to a dict of the keyword arguments.
  """

  def visit_FunctionDef(self, node):
    """Visits a function definition and adjusts its decorators.

    Args:
      node: gast.FunctionDef
    Returns:
      The same node, with its arguments, body and return annotation visited.
    """
    self.state[_Function].enter()
    node.args = self.visit(node.args)
    node.body = self.visit_block(node.body)

    if self.state[_Function].level < 2:
      # Top-level functions lose their decorator because the conversion is
      # always just-in-time and by the time it happens the decorators are
      # already set to be applied.
      node.decorator_list = []
    else:
      # Inner functions are converted already, so we insert a decorator to
      # prevent double conversion. Double conversion would work too, but this
      # saves the overhead.
      node.decorator_list.append(
          parser.parse_expression('ag__.do_not_convert_internal'))

    if node.returns:
      node.returns = self.visit(node.returns)

    self.state[_Function].exit()
    return node

  def visit_With(self, node):
    """Visits only the body of a `with` statement.

    Args:
      node: gast.With
    Returns:
      The same node, with its body visited.
    """
    # Context manager calls (in node.items) are not converted.
    node.body = self.visit_block(node.body)
    return node

  def visit_Call(self, node):
    """Rewrites a call into an `ag__.converted_call` call.

    Args:
      node: gast.Call
    Returns:
      The `ag__.converted_call(...)` expression, or the (visited) original
      node for calls that must not be converted.
    """
    # TODO(mdan): Refactor converted_call as a 'Call' operator.

    # Read the qualified name from the callee as it appears in the source,
    # before any nested rewrites below.
    full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))

    # Convert nested calls (in the callee expression and in all arguments)
    # first: the node returned from this method is not visited again, so
    # e.g. `g(x)` in `f(g(x))` would otherwise stay unconverted.
    node = self.generic_visit(node)

    # Calls to the internal 'ag__' module are never converted (though their
    # arguments might be).
    if full_name.startswith('ag__.'):
      return node
    if (full_name == 'print' and
        not self.ctx.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS)):
      return node

    if isinstance(node.func, gast.Attribute):
      func = gast.Str(node.func.attr)
      owner = node.func.value
    else:
      func = node.func
      owner = parser.parse_expression('None')

    template = """
      ag__.converted_call(func, owner, options, args, kwargs)
    """
    new_call = templates.replace_as_expression(
        template,
        func=func,
        owner=owner,
        options=self.ctx.program.options.to_ast(
            internal_convert_user_code=self.ctx.program.options.recursive),
        args=self._args_to_tuple(node),
        kwargs=self._kwargs_to_dict(node))

    return new_call

  def _args_to_tuple(self, node):
    """Builds an expression evaluating to the call's positional arguments.

    Argument order is preserved, and any number of `*args` may appear
    anywhere in the list (legal since Python 3.5, PEP 448). For example:

        f(a, *b, c, *d)  ->  (a,) + tuple(b) + (c,) + tuple(d)

    Args:
      node: gast.Call
    Returns:
      A gast expression that evaluates to a tuple.
    """
    segments = []
    pending = []  # Consecutive plain positional args not yet emitted.
    for arg in node.args:
      if isinstance(arg, gast.Starred):
        if pending:
          segments.append(
              templates.replace_as_expression('(args,)', args=pending))
          pending = []
        segments.append(
            templates.replace_as_expression(
                'tuple(stararg)', stararg=arg.value))
      else:
        pending.append(arg)
    if pending or not segments:
      # Also covers calls without positional arguments, which get the same
      # (empty) tuple template the plain case always used.
      segments.append(templates.replace_as_expression('(args,)', args=pending))

    result = segments[0]
    for segment in segments[1:]:
      result = templates.replace_as_expression(
          'left + right', left=result, right=segment)
    return result

  def _kwargs_to_dict(self, node):
    """Builds an expression evaluating to the call's keyword arguments.

    Without `**` unpacking, the keywords become a dict literal. With one or
    more `**` unpackings (several are legal since Python 3.5, PEP 448), the
    keywords are forwarded unchanged to `dict(...)`. This keeps their order
    and, like the original call, raises TypeError on duplicate names instead
    of silently letting one value override another.

    Args:
      node: gast.Call
    Returns:
      A gast expression that evaluates to a dict.
    """
    if all(k.arg is not None for k in node.keywords):
      return ast_util.keywords_to_dict(node.keywords)
    kwargs = parser.parse_expression('dict()')
    kwargs.keywords = list(node.keywords)
    return kwargs
137
138
def transform(node, ctx):
  """Transform function calls to the compiled counterparts.

  Calls inside `node` are rewritten into `ag__.converted_call(...)` calls by
  a CallTreeTransformer.

  Args:
    node: AST
    ctx: EntityContext
  Returns:
    The transformed AST node. Only the node is returned; no set of
    newly-generated names is produced.
  """
  return CallTreeTransformer(ctx).visit(node)
151