1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Handles directives. 16 17This converter removes the directive functions from the code and moves the 18information they specify into AST annotations. It is a specialized form of 19static analysis, one that is specific to AutoGraph. 20 21Note that this requires that the actual directive functions are static - that 22is, they do not change at runtime. So if you do something like this: 23 24 tf.autograph.set_loop_options = <new function> 25 26Then the directive will may no longer be recognized. Furthermore, if the 27converted function is cached, such an action may be irreversible. 28""" 29 30from __future__ import absolute_import 31from __future__ import division 32from __future__ import print_function 33 34import inspect 35 36import gast 37 38from tensorflow.python.autograph.core import converter 39from tensorflow.python.autograph.lang import directives 40from tensorflow.python.autograph.pyct import anno 41from tensorflow.python.util import tf_inspect 42 43 44STATIC_VALUE = 'static_value' 45"""Used for AST annotations, see visit_Name.""" 46 47 48class _LoopScope(object): 49 50 def __init__(self): 51 self.ast_node = None 52 self.statements_visited = 0 53 54 55def _map_args(call_node, function): 56 """Maps AST call nodes to the actual function's arguments. 57 58 Args: 59 call_node: ast.Call 60 function: Callable[..., Any], the actual function matching call_node 61 Returns: 62 Dict[Text, ast.AST], mapping each of the function's argument names to 63 the respective AST node. 64 Raises: 65 ValueError: if the default arguments are not correctly set 66 """ 67 args = call_node.args 68 kwds = {kwd.arg: kwd.value for kwd in call_node.keywords} 69 call_args = tf_inspect.getcallargs(function, *args, **kwds) 70 71 # Keyword arguments not specified in kwds will be mapped to their defaults, 72 # which are Python values. Since we don't currently have a way to transform 73 # those into AST references, we simply remove them. By convention, directives 74 # use UNSPECIFIED as default value for optional arguments. No other 75 # defaults should be present. 76 unexpected_defaults = [] 77 for k in call_args: 78 if (k not in kwds 79 and call_args[k] not in args 80 and call_args[k] is not directives.UNSPECIFIED): 81 unexpected_defaults.append(k) 82 if unexpected_defaults: 83 raise ValueError('Unexpected keyword argument values, %s, for function %s' 84 % (zip(unexpected_defaults, 85 [call_args[k] for k in unexpected_defaults]), 86 function)) 87 return {k: v for k, v in call_args.items() if v is not directives.UNSPECIFIED} 88 89 90class DirectivesTransformer(converter.Base): 91 """Parses compiler directives and converts them into AST annotations.""" 92 93 def _process_symbol_directive(self, call_node, directive): 94 if len(call_node.args) < 1: 95 raise ValueError('"%s" requires a positional first argument' 96 ' as the target' % directive.__name__) 97 target = call_node.args[0] 98 defs = anno.getanno(target, anno.Static.ORIG_DEFINITIONS) 99 for def_ in defs: 100 def_.directives[directive] = _map_args(call_node, directive) 101 return call_node 102 103 def _process_statement_directive(self, call_node, directive): 104 if self.state[_LoopScope].statements_visited > 1: 105 raise ValueError( 106 '"%s" must be the first statement in the loop block' % ( 107 directive.__name__)) 108 if self.state[_LoopScope].level < 2: 109 raise ValueError( 110 '"%s" must be used inside a statement' % directive.__name__) 111 target = self.state[_LoopScope].ast_node 112 node_anno = anno.getanno(target, anno.Basic.DIRECTIVES, {}) 113 node_anno[directive] = _map_args(call_node, directive) 114 anno.setanno(target, anno.Basic.DIRECTIVES, node_anno) 115 return call_node 116 117 def visit_Name(self, node): 118 node = self.generic_visit(node) 119 if isinstance(node.ctx, gast.Load): 120 defs = anno.getanno(node, anno.Static.DEFINITIONS, ()) 121 is_defined = bool(defs) 122 if not is_defined and node.id in self.ctx.info.namespace: 123 anno.setanno(node, STATIC_VALUE, self.ctx.info.namespace[node.id]) 124 return node 125 126 def visit_Attribute(self, node): 127 node = self.generic_visit(node) 128 parent_val = anno.getanno(node.value, STATIC_VALUE, default=None) 129 if parent_val is not None and inspect.ismodule(parent_val): 130 if hasattr(parent_val, node.attr): 131 anno.setanno(node, STATIC_VALUE, getattr(parent_val, node.attr)) 132 return node 133 134 def visit_Assign(self, node): 135 self.state[_LoopScope].statements_visited += 1 136 return self.generic_visit(node) 137 138 def visit_AugAssign(self, node): 139 self.state[_LoopScope].statements_visited += 1 140 return self.generic_visit(node) 141 142 def visit_Expr(self, node): 143 self.state[_LoopScope].statements_visited += 1 144 node = self.generic_visit(node) 145 if isinstance(node.value, gast.Call): 146 call_node = node.value 147 static_val = anno.getanno(call_node.func, STATIC_VALUE, default=None) 148 if static_val is not None: 149 # Note: directive calls are not output in the generated code, hence 150 # the removal from the code by returning None. 151 152 if static_val is directives.set_element_type: 153 self._process_symbol_directive(call_node, static_val) 154 return None 155 elif static_val is directives.set_loop_options: 156 self._process_statement_directive(call_node, static_val) 157 return None 158 return node 159 160 # TODO(mdan): This will be insufficient for other control flow. 161 # That means that if we ever have a directive that affects things other than 162 # loops, we'll need support for parallel scopes, or have multiple converters. 163 def _track_and_visit_loop(self, node): 164 self.state[_LoopScope].enter() 165 self.state[_LoopScope].ast_node = node 166 node = self.generic_visit(node) 167 # Edge case: a loop with just one directive statement would become empty. 168 if not node.body: 169 node.body = [gast.Pass()] 170 self.state[_LoopScope].exit() 171 return node 172 173 def visit_While(self, node): 174 return self._track_and_visit_loop(node) 175 176 def visit_For(self, node): 177 return self._track_and_visit_loop(node) 178 179 180def transform(node, ctx): 181 return DirectivesTransformer(ctx).visit(node) 182