1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""AST manipulation utilities.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import ast 22 23import gast 24 25from tensorflow.python.autograph.pyct import anno 26from tensorflow.python.autograph.pyct import parser 27from tensorflow.python.autograph.pyct import qual_names 28 29 30class CleanCopier(object): 31 """NodeTransformer-like visitor that copies an AST.""" 32 33 def __init__(self, preserve_annos): 34 super(CleanCopier, self).__init__() 35 self.preserve_annos = preserve_annos 36 37 def copy(self, node): 38 """Returns a deep copy of node (excluding some fields, see copy_clean).""" 39 40 if isinstance(node, list): 41 return [self.copy(n) for n in node] 42 elif isinstance(node, tuple): 43 return tuple(self.copy(n) for n in node) 44 elif not isinstance(node, (gast.AST, ast.AST)): 45 # Assuming everything that's not an AST, list or tuple is a value type 46 # and may simply be assigned. 47 return node 48 49 assert isinstance(node, (gast.AST, ast.AST)) 50 51 new_fields = {} 52 for f in node._fields: 53 if not f.startswith('__') and hasattr(node, f): 54 new_fields[f] = self.copy(getattr(node, f)) 55 new_node = type(node)(**new_fields) 56 57 if self.preserve_annos: 58 for k in self.preserve_annos: 59 anno.copyanno(node, new_node, k) 60 return new_node 61 62 63def copy_clean(node, preserve_annos=None): 64 """Creates a deep copy of an AST. 65 66 The copy will not include fields that are prefixed by '__', with the 67 exception of user-specified annotations. 68 69 Args: 70 node: ast.AST 71 preserve_annos: Optional[Set[Hashable]], annotation keys to include in the 72 copy 73 Returns: 74 ast.AST 75 """ 76 return CleanCopier(preserve_annos).copy(node) 77 78 79class SymbolRenamer(gast.NodeTransformer): 80 """Transformer that can rename symbols to a simple names.""" 81 82 def __init__(self, name_map): 83 self.name_map = name_map 84 85 def _process_name_node(self, node): 86 qn = anno.getanno(node, anno.Basic.QN) 87 if qn in self.name_map: 88 new_node = gast.Name( 89 str(self.name_map[qn]), 90 ctx=node.ctx, 91 annotation=None, 92 type_comment=None) 93 # All annotations get carried over. 94 for k in anno.keys(node): 95 anno.copyanno(node, new_node, k) 96 return new_node 97 return self.generic_visit(node) 98 99 def _process_list_of_strings(self, names): 100 for i in range(len(names)): 101 qn = qual_names.QN(names[i]) 102 if qn in self.name_map: 103 names[i] = str(self.name_map[qn]) 104 return names 105 106 def visit_Nonlocal(self, node): 107 node.names = self._process_list_of_strings(node.names) 108 return node 109 110 def visit_Global(self, node): 111 node.names = self._process_list_of_strings(node.names) 112 return node 113 114 def visit_Name(self, node): 115 return self._process_name_node(node) 116 117 def visit_Attribute(self, node): 118 if anno.hasanno(node, anno.Basic.QN): 119 return self._process_name_node(node) 120 # Renaming attributes is not supported. 121 return self.generic_visit(node) 122 123 def visit_FunctionDef(self, node): 124 qn = qual_names.QN(node.name) 125 if qn in self.name_map: 126 node.name = str(self.name_map[qn]) 127 return self.generic_visit(node) 128 129 130def rename_symbols(node, name_map): 131 """Renames symbols in an AST. Requires qual_names annotations.""" 132 renamer = SymbolRenamer(name_map) 133 if isinstance(node, list): 134 return [renamer.visit(n) for n in node] 135 elif isinstance(node, tuple): 136 return tuple(renamer.visit(n) for n in node) 137 return renamer.visit(node) 138 139 140def keywords_to_dict(keywords): 141 """Converts a list of ast.keyword objects to a dict.""" 142 keys = [] 143 values = [] 144 for kw in keywords: 145 keys.append(gast.Constant(kw.arg, kind=None)) 146 values.append(kw.value) 147 return gast.Dict(keys=keys, values=values) 148 149 150class PatternMatcher(gast.NodeVisitor): 151 """Matches a node against a pattern represented by a node.""" 152 153 def __init__(self, pattern): 154 self.pattern = pattern 155 self.pattern_stack = [] 156 self.matches = True 157 158 def compare_and_visit(self, node, pattern): 159 self.pattern_stack.append(self.pattern) 160 self.pattern = pattern 161 self.generic_visit(node) 162 self.pattern = self.pattern_stack.pop() 163 164 def no_match(self): 165 self.matches = False 166 return False 167 168 def is_wildcard(self, p): 169 if isinstance(p, (list, tuple)) and len(p) == 1: 170 p, = p 171 if isinstance(p, gast.Name) and p.id == '_': 172 return True 173 if p == '_': 174 return True 175 return False 176 177 def generic_visit(self, node): 178 if not self.matches: 179 return 180 181 pattern = self.pattern 182 for f in node._fields: 183 if f.startswith('__'): 184 continue 185 186 if not hasattr(node, f): 187 if hasattr(pattern, f) and getattr(pattern, f): 188 return self.no_match() 189 else: 190 continue 191 if not hasattr(pattern, f): 192 return self.no_match() 193 194 v = getattr(node, f) 195 p = getattr(pattern, f) 196 197 if self.is_wildcard(p): 198 continue 199 if isinstance(v, (list, tuple)): 200 if not isinstance(p, (list, tuple)) or len(v) != len(p): 201 return self.no_match() 202 for v_item, p_item in zip(v, p): 203 self.compare_and_visit(v_item, p_item) 204 elif isinstance(v, (gast.AST, ast.AST)): 205 if not isinstance(v, type(p)) and not isinstance(p, type(v)): 206 return self.no_match() 207 self.compare_and_visit(v, p) 208 else: 209 # Assume everything else is a value type. 210 if v != p: 211 return self.no_match() 212 213 214def matches(node, pattern): 215 """Basic pattern matcher for AST. 216 217 The pattern may contain wildcards represented by the symbol '_'. A node 218 matches a pattern if for every node in the tree, either there is a node of 219 the same type in pattern, or a Name node with id='_'. 220 221 Args: 222 node: ast.AST 223 pattern: ast.AST 224 Returns: 225 bool 226 """ 227 if isinstance(pattern, str): 228 pattern = parser.parse_str(pattern) 229 230 matcher = PatternMatcher(pattern) 231 matcher.visit(node) 232 return matcher.matches 233 234 235# TODO(mdan): Once we have error tracing, we may be able to just go to SSA. 236def apply_to_single_assignments(targets, values, apply_fn): 237 """Applies a function to each individual assignment. 238 239 This function can process a possibly-unpacked (e.g. a, b = c, d) assignment. 240 It tries to break down the unpacking if possible. In effect, it has the same 241 effect as passing the assigned values in SSA form to apply_fn. 242 243 Examples: 244 245 The following will result in apply_fn(a, c), apply_fn(b, d): 246 247 a, b = c, d 248 249 The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]): 250 251 a, b = c 252 253 The following will result in apply_fn(a, (b, c)): 254 255 a = b, c 256 257 It uses the visitor pattern to allow subclasses to process single 258 assignments individually. 259 260 Args: 261 targets: Union[List[ast.AST, ...], Tuple[ast.AST, ...], ast.AST, should be 262 used with the targets field of an ast.Assign node 263 values: ast.AST 264 apply_fn: Callable[[ast.AST, ast.AST], None], called with the 265 respective nodes of each single assignment 266 """ 267 if not isinstance(targets, (list, tuple)): 268 targets = (targets,) 269 for target in targets: 270 if isinstance(target, (gast.Tuple, gast.List)): 271 for i in range(len(target.elts)): 272 target_el = target.elts[i] 273 if isinstance(values, (gast.Tuple, gast.List)): 274 value_el = values.elts[i] 275 else: 276 idx = parser.parse_expression(str(i)) 277 value_el = gast.Subscript(values, idx, ctx=gast.Load()) 278 apply_to_single_assignments(target_el, value_el, apply_fn) 279 else: 280 apply_fn(target, values) 281 282 283def parallel_walk(node, other): 284 """Walks two ASTs in parallel. 285 286 The two trees must have identical structure. 287 288 Args: 289 node: Union[ast.AST, Iterable[ast.AST]] 290 other: Union[ast.AST, Iterable[ast.AST]] 291 Yields: 292 Tuple[ast.AST, ast.AST] 293 Raises: 294 ValueError: if the two trees don't have identical structure. 295 """ 296 if isinstance(node, (list, tuple)): 297 node_stack = list(node) 298 else: 299 node_stack = [node] 300 301 if isinstance(other, (list, tuple)): 302 other_stack = list(other) 303 else: 304 other_stack = [other] 305 306 while node_stack and other_stack: 307 assert len(node_stack) == len(other_stack) 308 n = node_stack.pop() 309 o = other_stack.pop() 310 311 if ((not isinstance(n, (ast.AST, gast.AST, str)) and n is not None) or 312 (not isinstance(o, (ast.AST, gast.AST, str)) and n is not None) or 313 n.__class__.__name__ != o.__class__.__name__): 314 raise ValueError('inconsistent nodes: {} ({}) and {} ({})'.format( 315 n, n.__class__.__name__, o, o.__class__.__name__)) 316 317 yield n, o 318 319 if isinstance(n, str): 320 assert isinstance(o, str), 'The check above should have ensured this' 321 continue 322 if n is None: 323 assert o is None, 'The check above should have ensured this' 324 continue 325 326 for f in n._fields: 327 n_child = getattr(n, f, None) 328 o_child = getattr(o, f, None) 329 if f.startswith('__') or n_child is None or o_child is None: 330 continue 331 332 if isinstance(n_child, (list, tuple)): 333 if (not isinstance(o_child, (list, tuple)) or 334 len(n_child) != len(o_child)): 335 raise ValueError( 336 'inconsistent values for field {}: {} and {}'.format( 337 f, n_child, o_child)) 338 node_stack.extend(n_child) 339 other_stack.extend(o_child) 340 341 elif isinstance(n_child, (gast.AST, ast.AST)): 342 node_stack.append(n_child) 343 other_stack.append(o_child) 344 345 elif n_child != o_child: 346 raise ValueError( 347 'inconsistent values for field {}: {} and {}'.format( 348 f, n_child, o_child)) 349