# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts function definitions and lambdas by adding necessary boilerplate."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gast

from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import annos


class _Function(object):
  """Per-function transformer state, entered via `self.state[_Function]`."""

  def __init__(self):
    # Name of the scope context variable generated for this function
    # ('fscope...' for defs, 'lscope...' for top-level lambdas); stays None
    # until the visitor assigns it.
    self.context_name = None


def _split_docstring(body):
  """Splits a leading docstring statement off a function body.

  Only an expression statement whose value is a *string* constant counts as a
  docstring, matching Python's own definition (see `ast.get_docstring`). Other
  leading constants, such as the `...` of a stub `def f(): ...`, stay in the
  body so they are not hoisted out of the function scope.

  Args:
    body: list of gast statement nodes forming a function body.

  Returns:
    A `(docstring_node, remaining_body)` tuple. `docstring_node` is None when
    there is no docstring, in which case `remaining_body` is `body` itself.
  """
  if body:
    first_statement = body[0]
    if (isinstance(first_statement, gast.Expr) and
        isinstance(first_statement.value, gast.Constant) and
        isinstance(first_statement.value.value, str)):
      return first_statement, body[1:]
  return None, body


class FunctionTransformer(converter.Base):
  """Wraps function bodies around autograph-specific boilerplate."""

  def _function_scope_options(self, fn_scope):
    """Returns the options with which to create function scopes.

    Args:
      fn_scope: the `_Function` state of the function being converted.

    Returns:
      The user-requested options for the top-level function (level 2), and
      the recursive-conversion options (`call_options()`) for anything nested.
    """
    # Top-level functions receive the options that were directly requested.
    # All others receive the options corresponding to a recursive conversion.
    # Note: this mainly controls the user_requested flag, which is important
    # primarily because the FunctionScope context also creates a
    # ControlStatusCtx(autograph=ENABLED) when user_requested is True. See
    # function_wrappers.py.
    if fn_scope.level == 2:
      return self.ctx.user.options
    return self.ctx.user.options.call_options()

  def visit_Lambda(self, node):
    """Wraps a top-level lambda's body in `ag__.with_function_scope`.

    Lambdas nested deeper than the top level are instead wrapped in
    `ag__.autograph_artifact(...)` and receive no function scope.
    """
    with self.state[_Function] as fn_scope:
      # Children are visited first; the scope symbol below is generated
      # afterwards (order matters for the names the namer hands out).
      node = self.generic_visit(node)

      # TODO(mdan): Fix the tests so that we can always add this decorator.
      if fn_scope.level > 2:
        return templates.replace_as_expression(
            'ag__.autograph_artifact(l)', l=node)

      scope = anno.getanno(node, anno.Static.SCOPE)
      function_context_name = self.ctx.namer.new_symbol('lscope',
                                                        scope.referenced)
      fn_scope.context_name = function_context_name
      anno.setanno(node, 'function_context_name', function_context_name)

      template = """
        ag__.with_function_scope(
            lambda function_context: body, function_context_name, options)
      """
      node.body = templates.replace_as_expression(
          template,
          options=self._function_scope_options(fn_scope).to_ast(),
          function_context=function_context_name,
          function_context_name=gast.Constant(function_context_name, kind=None),
          body=node.body)

      return node

  def visit_FunctionDef(self, node):
    """Wraps a function's body in an `ag__.FunctionScope` context manager.

    Top-level functions have their decorators stripped; nested functions get
    an `ag__.autograph_artifact` decorator appended. A leading docstring is
    kept as the first statement, outside the `with` block, so that it remains
    the function's docstring.
    """
    with self.state[_Function] as fn_scope:
      scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)

      # The scope symbol is generated before visiting children, so nested
      # functions see this name as already taken.
      function_context_name = self.ctx.namer.new_symbol('fscope',
                                                        scope.referenced)
      fn_scope.context_name = function_context_name
      anno.setanno(node, 'function_context_name', function_context_name)

      node = self.generic_visit(node)

      if fn_scope.level <= 2:
        # Top-level functions lose their decorator because the conversion is
        # always just-in-time and by the time it happens the decorators are
        # already set to be applied.
        node.decorator_list = []
      else:
        # TODO(mdan): Fix the tests so that we can always add this decorator.
        # Inner functions are converted already, so we insert a decorator to
        # prevent double conversion. Double conversion would work too, but this
        # saves the overhead.
        node.decorator_list.append(
            parser.parse_expression('ag__.autograph_artifact'))

      docstring_node, body = _split_docstring(node.body)
      if not body:
        # A docstring-only function would otherwise yield a `with` statement
        # with an empty body, which is not valid Python.
        body = [gast.Pass()]

      template = """
        with ag__.FunctionScope(
            function_name, context_name, options) as function_context:
          body
      """
      wrapped_body = templates.replace(
          template,
          function_name=gast.Constant(node.name, kind=None),
          context_name=gast.Constant(function_context_name, kind=None),
          options=self._function_scope_options(fn_scope).to_ast(),
          function_context=function_context_name,
          body=body)

      if docstring_node is not None:
        wrapped_body = [docstring_node] + wrapped_body

      node.body = wrapped_body

      return node


def transform(node, ctx):
  """Runs the analyses FunctionTransformer needs, then applies it.

  Args:
    node: the AST to convert.
    ctx: the converter context, shared by activity analysis and the
      transformer.

  Returns:
    The transformed AST.
  """
  # The transformer reads scope annotations (anno.Static.SCOPE,
  # NodeAnno.BODY_SCOPE); these passes are presumably what attach them.
  node = qual_names.resolve(node)
  node = activity.resolve(node, ctx, None)

  return FunctionTransformer(ctx).visit(node)