# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Handles function calls, by rewriting them into `ag__.converted_call` calls.

Note: this transformer does not rename the top level object being converted;
that is the caller's responsibility.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gast

from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import templates


# TODO(mdan): Rename to FunctionCallsTransformer.


class _Function(object):
  """Key for the per-function state tracked in `CallTreeTransformer.state`."""

  # NOTE(review): presumably tells the state stack not to create an implicit
  # root entry, so that `level` counts only enclosing function definitions;
  # confirm against the converter state implementation.
  no_root = True


def _args_to_tuple(call_args):
  """Combines positional and `*args` arguments into one tuple expression.

  Argument order is preserved, so a (Python 3 only) call such as
  `f(*a, b, *c, d)` yields `tuple(a) + (b,) + tuple(c) + (d,)`. A call without
  starred arguments yields a plain tuple literal, e.g. `(a, b)`; a call with
  no arguments at all yields `()`.

  Args:
    call_args: list of gast expression nodes, the `args` field of a gast.Call.

  Returns:
    A gast expression evaluating to the tuple of positional arguments.
  """
  segments = []
  positional = []
  for arg in call_args:
    if isinstance(arg, gast.Starred):
      # Close off the positional arguments seen so far so they stay in front
      # of the unpacked iterable.
      if positional:
        segments.append(
            templates.replace_as_expression('(args,)', args=positional))
        positional = []
      segments.append(
          templates.replace_as_expression('tuple(stararg)', stararg=arg.value))
    else:
      positional.append(arg)
  if positional or not segments:
    segments.append(templates.replace_as_expression('(args,)', args=positional))

  result = segments[0]
  for segment in segments[1:]:
    result = templates.replace_as_expression(
        'left + right', left=result, right=segment)
  return result


def _keywords_to_dict(keywords):
  """Combines keyword and `**kwargs` arguments into one dict expression.

  Args:
    keywords: list of gast.keyword nodes, the `keywords` field of a gast.Call.
      Entries whose `arg` is None are `**mapping` unpackings.

  Returns:
    A gast expression evaluating to the dict of keyword arguments.
  """
  unpacked = [k for k in keywords if k.arg is None]
  named = [k for k in keywords if k.arg is not None]
  if not unpacked:
    return ast_util.keywords_to_dict(named)
  if len(unpacked) == 1:
    return templates.replace_as_expression(
        'dict(kwargs, **keywords)',
        kwargs=unpacked[0].value,
        keywords=ast_util.keywords_to_dict(named))
  # Several `**` unpackings can only come from Python 3.5+ code (PEP 448).
  # Forwarding all keywords to `dict` in their original order keeps the same
  # duplicate-key TypeError that the original call would raise.
  merged = parser.parse_expression('dict()')
  merged.keywords = list(keywords)
  return merged


class CallTreeTransformer(converter.Base):
  """Transforms the call tree by renaming transformed symbols.

  Calls are replaced with
  `ag__.converted_call(func, owner, options, args, kwargs)`. Calls into the
  `ag__` module, and calls to `print` when the BUILTIN_FUNCTIONS feature is
  not enabled, are left in place.
  """

  def visit_FunctionDef(self, node):
    """Visits a function and adjusts its decorators by nesting level."""
    self.state[_Function].enter()
    node.args = self.visit(node.args)
    node.body = self.visit_block(node.body)

    if self.state[_Function].level < 2:
      # Top-level functions lose their decorator because the conversion is
      # always just-in-time and by the time it happens the decorators are
      # already set to be applied.
      node.decorator_list = []
    else:
      # Inner functions are converted already, so we insert a decorator to
      # prevent double conversion. Double conversion would work too, but this
      # saves the overhead.
      node.decorator_list.append(
          parser.parse_expression('ag__.do_not_convert_internal'))

    if node.returns:
      node.returns = self.visit(node.returns)

    self.state[_Function].exit()
    return node

  def visit_With(self, node):
    """Visits only the body of a `with` statement."""
    # Context manager calls (in node.items) are not converted.
    node.body = self.visit_block(node.body)
    return node

  def visit_Call(self, node):
    """Replaces a call with an `ag__.converted_call` invocation.

    For `obj.method(...)` the method name is passed as a string and `obj` as
    the owner; for any other callee the owner is `None`. Positional and
    `*args` arguments are combined into a single tuple expression, keyword and
    `**kwargs` arguments into a single dict expression, preserving order.

    Args:
      node: gast.Call

    Returns:
      The replacement expression node.
    """
    # TODO(mdan): Refactor converted_call as a 'Call' operator.

    # Calls to the internal 'ag__' module are never converted (though their
    # arguments might be).
    full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
    if full_name.startswith('ag__.'):
      return self.generic_visit(node)
    if (full_name == 'print' and
        not self.ctx.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS)):
      return self.generic_visit(node)

    # NOTE(review): unlike the two cases above, the callee, owner, arguments
    # and keywords of a converted call are not visited, so calls nested inside
    # them stay unconverted. Confirm this is intended before changing it, as
    # some calls (e.g. frame-sensitive builtins) may rely on it.
    if isinstance(node.func, gast.Attribute):
      func = gast.Str(node.func.attr)
      owner = node.func.value
    else:
      func = node.func
      owner = parser.parse_expression('None')

    args = _args_to_tuple(node.args)
    kwargs = _keywords_to_dict(node.keywords)

    template = """
      ag__.converted_call(func, owner, options, args, kwargs)
    """
    new_call = templates.replace_as_expression(
        template,
        func=func,
        owner=owner,
        options=self.ctx.program.options.to_ast(
            internal_convert_user_code=self.ctx.program.options.recursive),
        args=args,
        kwargs=kwargs)

    return new_call


def transform(node, ctx):
  """Rewrites function calls in `node` into `ag__.converted_call` calls.

  Args:
    node: AST
    ctx: EntityContext

  Returns:
    The transformed AST node, as returned by `CallTreeTransformer.visit`.
  """
  return CallTreeTransformer(ctx).visit(node)