1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""AST conversion templates. 16 17Adapted from Tangent. 18""" 19 20from __future__ import absolute_import 21from __future__ import division 22from __future__ import print_function 23 24import ast 25import textwrap 26 27import gast 28 29from tensorflow.python.autograph.pyct import anno 30from tensorflow.python.autograph.pyct import ast_util 31from tensorflow.python.autograph.pyct import parser 32from tensorflow.python.autograph.pyct import qual_names 33 34 35class ContextAdjuster(gast.NodeTransformer): 36 """Adjusts the ctx field of nodes to ensure consistency. 37 38 This transformer can change the ctx fields of a variable, tuple and other 39 AST elements that allow one, based on whether the element is being read or 40 written. 41 """ 42 43 def __init__(self, override_value): 44 self._ctx_override = override_value 45 46 def visit(self, node): 47 original_override = self._ctx_override 48 node = super(ContextAdjuster, self).visit(node) 49 if hasattr(node, 'ctx'): 50 assert node.ctx is not None, 'node {} has ctx unset'.format(node) 51 self._ctx_override = original_override 52 return node 53 54 def _apply_override(self, node): 55 if self._ctx_override is not None: 56 node.ctx = self._ctx_override() 57 58 def visit_Attribute(self, node): 59 self._apply_override(node) 60 self._ctx_override = gast.Load 61 node = self.generic_visit(node) 62 return node 63 64 def visit_Tuple(self, node): 65 self._apply_override(node) 66 return self.generic_visit(node) 67 68 def visit_List(self, node): 69 self._apply_override(node) 70 return self.generic_visit(node) 71 72 def visit_Name(self, node): 73 self._apply_override(node) 74 return self.generic_visit(node) 75 76 def visit_Call(self, node): 77 self._apply_override(node) 78 # We may be able to override these to Load(), but for now it's simpler 79 # to just assert that they're set. 80 self._ctx_override = None 81 return self.generic_visit(node) 82 83 def visit_Dict(self, node): 84 # We may be able to override these to Load(), but for now it's simpler 85 # to just assert that they're set. 86 self._ctx_override = None 87 return self.generic_visit(node) 88 89 def visit_Subscript(self, node): 90 self._apply_override(node) 91 self._ctx_override = gast.Load 92 node.value = self.visit(node.value) 93 return self.generic_visit(node) 94 95 def visit_comprehension(self, node): 96 # We may be able to override some of these, but for now it's simpler 97 # to just assert that they're set. 98 self._ctx_override = None 99 return self.generic_visit(node) 100 101 def visit_Lambda(self, node): 102 # We may be able to override some of these, but for now it's simpler 103 # to just assert that they're set. 104 self._ctx_override = None 105 return self.generic_visit(node) 106 107 108class ReplaceTransformer(gast.NodeTransformer): 109 """Replace AST nodes.""" 110 111 def __init__(self, replacements): 112 """Create a new ReplaceTransformer. 113 114 Args: 115 replacements: A mapping from placeholder names to (lists of) AST nodes 116 that these placeholders will be replaced by. 117 """ 118 self.replacements = replacements 119 self.in_replacements = False 120 self.preserved_annos = { 121 anno.Basic.DIRECTIVES, 122 anno.Basic.EXTRA_LOOP_TEST, 123 anno.Basic.ORIGIN, 124 anno.Basic.SKIP_PROCESSING, 125 anno.Static.ORIG_DEFINITIONS, 126 'function_context_name', 127 } 128 129 def _prepare_replacement(self, replaced, key): 130 """Prepares a replacement AST that's safe to swap in for a node. 131 132 Args: 133 replaced: ast.AST, the node being replaced 134 key: Hashable, the key of the replacement AST 135 Returns: 136 ast.AST, the replacement AST 137 """ 138 repl = self.replacements[key] 139 140 new_nodes = ast_util.copy_clean(repl, preserve_annos=self.preserved_annos) 141 if isinstance(new_nodes, gast.AST): 142 new_nodes = [new_nodes] 143 144 return new_nodes 145 146 def visit_Expr(self, node): 147 # When replacing a placeholder with an entire statement, the replacement 148 # must stand on its own and not be wrapped in an Expr. 149 new_value = self.visit(node.value) 150 if new_value is node.value: 151 return node 152 return new_value 153 154 def visit_keyword(self, node): 155 if node.arg not in self.replacements: 156 return self.generic_visit(node) 157 158 repl = self._prepare_replacement(node, node.arg) 159 if isinstance(repl, gast.keyword): 160 return repl 161 elif (repl and isinstance(repl, (list, tuple)) and 162 all(isinstance(r, gast.keyword) for r in repl)): 163 return repl 164 # TODO(mdan): We may allow replacing with a string as well. 165 # For example, if one wanted to replace foo with bar in foo=baz, then 166 # we could allow changing just node arg, so that we end up with bar=baz. 167 raise ValueError( 168 'a keyword argument may only be replaced by another keyword or a ' 169 'non-empty list of keywords. Found: {} for keyword {}'.format( 170 repl, node.arg)) 171 172 def visit_FunctionDef(self, node): 173 node = self.generic_visit(node) 174 if node.name not in self.replacements: 175 return node 176 177 repl = self.replacements[node.name] 178 if not isinstance(repl, (gast.Name, ast.Name)): 179 raise ValueError( 180 'a function name can only be replaced by a Name node. Found: %s' % 181 repl) 182 node.name = repl.id 183 return node 184 185 def visit_Attribute(self, node): 186 node = self.generic_visit(node) 187 if node.attr not in self.replacements: 188 return node 189 190 repl = self.replacements[node.attr] 191 if not isinstance(repl, gast.Name): 192 raise ValueError( 193 'An attribute can only be replaced by a Name node. Found: %s' % repl) 194 node.attr = repl.id 195 return node 196 197 def visit_Name(self, node): 198 if node.id not in self.replacements: 199 return node 200 201 new_nodes = self._prepare_replacement(node, node.id) 202 203 if not new_nodes: 204 return new_nodes 205 206 # Preserve the target context. 207 adjuster = ContextAdjuster(type(node.ctx)) 208 for n in new_nodes: 209 if hasattr(n, 'ctx'): 210 adjuster.visit(n) 211 212 if len(new_nodes) == 1: 213 new_nodes, = new_nodes 214 215 return new_nodes 216 217 218def _convert_to_ast(n): 219 """Converts from a known data type to AST.""" 220 # Note: When generating AST nodes from strings/QNs in isolation, ctx is 221 # unknown. ctx must be filled in according to the template being used. 222 # See ReplaceTransformer.visit_Name. 223 if isinstance(n, str): 224 return gast.Name(id=n, ctx=None, annotation=None, type_comment=None) 225 if isinstance(n, qual_names.QN): 226 return n.ast() 227 if isinstance(n, list): 228 return [_convert_to_ast(e) for e in n] 229 if isinstance(n, tuple): 230 return tuple(_convert_to_ast(e) for e in n) 231 return n 232 233 234def replace(template, **replacements): 235 """Replaces placeholders in a Python template. 236 237 AST Name and Tuple nodes always receive the context that inferred from 238 the template. However, when replacing more complex nodes (that can potentially 239 contain Name children), then the caller is responsible for setting the 240 appropriate context. 241 242 Args: 243 template: A string representing Python code. Any symbol name can be used 244 that appears in the template code can be used as placeholder. 245 **replacements: A mapping from placeholder names to (lists of) AST nodes 246 that these placeholders will be replaced by. String values are also 247 supported as a shorthand for AST Name nodes with the respective ID. 248 249 Returns: 250 An AST node or list of AST nodes with the replacements made. If the 251 template was a function, a list will be returned. If the template was a 252 node, the same node will be returned. If the template was a string, an 253 AST node will be returned (a `Module` node in the case of a multi-line 254 string, an `Expr` node otherwise). 255 256 Raises: 257 ValueError: if the arguments are incorrect. 258 """ 259 if not isinstance(template, str): 260 raise ValueError('Expected string template, got %s' % type(template)) 261 for k in replacements: 262 replacements[k] = _convert_to_ast(replacements[k]) 263 template_str = parser.STANDARD_PREAMBLE + textwrap.dedent(template) 264 nodes = parser.parse( 265 template_str, 266 preamble_len=parser.STANDARD_PREAMBLE_LEN, 267 single_node=False) 268 results = [] 269 for node in nodes: 270 node = ReplaceTransformer(replacements).visit(node) 271 if isinstance(node, (list, tuple)): 272 results.extend(node) 273 else: 274 results.append(node) 275 results = [qual_names.resolve(r) for r in results] 276 return results 277 278 279def replace_as_expression(template, **replacements): 280 """Variant of replace that generates expressions, instead of code blocks.""" 281 replacement = replace(template, **replacements) 282 if len(replacement) != 1: 283 raise ValueError( 284 'single expression expected; for more general templates use replace') 285 node, = replacement 286 287 if isinstance(node, gast.Expr): 288 return node.value 289 elif isinstance(node, gast.Name): 290 return node 291 292 raise ValueError( 293 'the template is expected to generate an expression or a name node;' 294 ' instead found %s' % node) 295