# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""AST manipulation utilities."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import ast

import gast

from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.util import tf_inspect


class CleanCopier(object):
  """NodeTransformer-like visitor that copies an AST."""

  def __init__(self, preserve_annos):
    """Creates a copier.

    Args:
      preserve_annos: Optional iterable of annotation keys that are copied from
        each original node onto its copy. May be None or empty.
    """
    super(CleanCopier, self).__init__()
    self.preserve_annos = preserve_annos

  def copy(self, node):
    """Returns a deep copy of node (excluding some fields, see copy_clean).

    Args:
      node: an AST node, a list or tuple of such values, or any other value.
    Returns:
      A new list/tuple/node with recursively copied contents. Values that are
      not an AST node, list or tuple are returned as they are.
    """

    if isinstance(node, list):
      return [self.copy(n) for n in node]
    elif isinstance(node, tuple):
      return tuple(self.copy(n) for n in node)
    elif not isinstance(node, (gast.AST, ast.AST)):
      # Assuming everything that's not an AST, list or tuple is a value type
      # and may simply be assigned.
      return node

    assert isinstance(node, (gast.AST, ast.AST))

    # Only fields that are actually set on the node are forwarded to the
    # constructor; fields prefixed with '__' are dropped from the copy.
    new_fields = {}
    for f in node._fields:
      if not f.startswith('__') and hasattr(node, f):
        new_fields[f] = self.copy(getattr(node, f))
    new_node = type(node)(**new_fields)

    if self.preserve_annos:
      for k in self.preserve_annos:
        anno.copyanno(node, new_node, k)
    return new_node


def copy_clean(node, preserve_annos=None):
  """Creates a deep copy of an AST.

  The copy will not include fields that are prefixed by '__', with the
  exception of user-specified annotations.

  Args:
    node: ast.AST
    preserve_annos: Optional[Set[Hashable]], annotation keys to include in the
        copy
  Returns:
    ast.AST
  """
  return CleanCopier(preserve_annos).copy(node)


class SymbolRenamer(gast.NodeTransformer):
  """Transformer that can rename symbols to simple names."""

  def __init__(self, name_map):
    """Creates a renamer.

    Args:
      name_map: mapping that is looked up with the anno.Basic.QN annotation of
        each visited node; the string form of the mapped value becomes the new
        name.
    """
    self.name_map = name_map

  def _process(self, node):
    """Replaces node with a Name node if its QN annotation is in name_map."""
    qn = anno.getanno(node, anno.Basic.QN)
    if qn in self.name_map:
      new_node = gast.Name(str(self.name_map[qn]), node.ctx, None)
      # All annotations get carried over.
      for k in anno.keys(node):
        anno.copyanno(node, new_node, k)
      return new_node
    return self.generic_visit(node)

  def visit_Name(self, node):
    return self._process(node)

  def visit_Attribute(self, node):
    if anno.hasanno(node, anno.Basic.QN):
      return self._process(node)
    # Attributes of dynamic objects will not have a QN.
    return self.generic_visit(node)


def rename_symbols(node, name_map):
  """Renames symbols in an AST. Requires qual_names annotations.

  Args:
    node: an AST node, or a list or tuple of AST nodes.
    name_map: mapping from the qualified-name annotation of a symbol to its
      replacement; see SymbolRenamer.
  Returns:
    The result of visiting node with a SymbolRenamer. A list input produces a
    list and a tuple input produces a tuple.
  """
  renamer = SymbolRenamer(name_map)
  if isinstance(node, list):
    return [renamer.visit(n) for n in node]
  elif isinstance(node, tuple):
    return tuple(renamer.visit(n) for n in node)
  return renamer.visit(node)


def keywords_to_dict(keywords):
  """Converts a list of ast.keyword objects to a dict.

  Args:
    keywords: iterable of keyword nodes, each exposing `arg` and `value`.
  Returns:
    A gast.Dict node whose keys are gast.Str nodes built from each keyword's
    `arg` and whose values are the keywords' `value` nodes, in the same order.
  """
  keys = []
  values = []
  for kw in keywords:
    keys.append(gast.Str(kw.arg))
    values.append(kw.value)
  return gast.Dict(keys=keys, values=values)


class PatternMatcher(gast.NodeVisitor):
  """Matches a node against a pattern represented by a node."""

  def __init__(self, pattern):
    self.pattern = pattern
    # Patterns of the enclosing nodes, restored as the traversal unwinds.
    self.pattern_stack = []
    # Flipped to False by no_match; once False, further visits are no-ops.
    self.matches = True

  def compare_and_visit(self, node, pattern):
    """Visits the fields of node using pattern as the current pattern."""
    self.pattern_stack.append(self.pattern)
    self.pattern = pattern
    self.generic_visit(node)
    self.pattern = self.pattern_stack.pop()

  def no_match(self):
    """Records that the match has failed. Always returns False."""
    self.matches = False
    return False

  def is_wildcard(self, p):
    """Returns True if the pattern value p stands for the wildcard '_'.

    A wildcard is a Name node with id '_', the value '_' itself, or a list or
    tuple that contains exactly one such element.
    """
    if isinstance(p, (list, tuple)) and len(p) == 1:
      p, = p
      if isinstance(p, gast.Name) and p.id == '_':
        return True
    if p == '_':
      return True
    return False

  def generic_visit(self, node):
    """Compares every field of node with the same field of the pattern."""
    if not self.matches:
      return

    pattern = self.pattern
    for f in node._fields:
      if f.startswith('__'):
        continue

      if not hasattr(node, f):
        # A field missing from the node only matches a pattern field that is
        # missing as well, or empty.
        if hasattr(pattern, f) and getattr(pattern, f):
          return self.no_match()
        else:
          continue
      if not hasattr(pattern, f):
        return self.no_match()

      v = getattr(node, f)
      p = getattr(pattern, f)

      if self.is_wildcard(p):
        continue
      if isinstance(v, (list, tuple)):
        if not isinstance(p, (list, tuple)) or len(v) != len(p):
          return self.no_match()
        for v_item, p_item in zip(v, p):
          self.compare_and_visit(v_item, p_item)
      elif isinstance(v, (gast.AST, ast.AST)):
        if not isinstance(v, type(p)) and not isinstance(p, type(v)):
          return self.no_match()
        self.compare_and_visit(v, p)
      else:
        # Assume everything else is a value type.
        if v != p:
          return self.no_match()


def matches(node, pattern):
  """Basic pattern matcher for AST.

  The pattern may contain wildcards represented by the symbol '_'. A node
  matches a pattern if for every node in the tree, either there is a node of
  the same type in pattern, or a Name node with id='_'.

  Args:
    node: ast.AST
    pattern: ast.AST, or a str that is parsed into one; the parsed source must
        contain exactly one top-level statement
  Returns:
    bool
  """
  if isinstance(pattern, str):
    pattern, = parser.parse_str(pattern).body

  matcher = PatternMatcher(pattern)
  matcher.visit(node)
  return matcher.matches


# TODO(mdan): Once we have error tracing, we may be able to just go to SSA.
def apply_to_single_assignments(targets, values, apply_fn):
  """Applies a function to each individual assignment.

  This function can process a possibly-unpacked (e.g. a, b = c, d) assignment.
  It tries to break down the unpacking if possible. In effect, it has the same
  effect as passing the assigned values in SSA form to apply_fn.

  Examples:

  The following will result in apply_fn(a, c), apply_fn(b, d):

      a, b = c, d

  The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]):

      a, b = c

  The following will result in apply_fn(a, (b, c)):

      a = b, c

  It uses the visitor pattern to allow subclasses to process single
  assignments individually.

  Args:
    targets: Union[List[ast.AST, ...], Tuple[ast.AST, ...], ast.AST], should be
        used with the targets field of an ast.Assign node
    values: ast.AST
    apply_fn: Callable[[ast.AST, ast.AST], None], called with the
        respective nodes of each single assignment
  """
  if not isinstance(targets, (list, tuple)):
    targets = (targets,)
  for target in targets:
    if isinstance(target, (gast.Tuple, gast.List)):
      for i, target_el in enumerate(target.elts):
        if isinstance(values, (gast.Tuple, gast.List)):
          value_el = values.elts[i]
        else:
          # The value is not a literal sequence, so the i-th element is
          # expressed as a subscript of the whole value: values[i].
          idx = parser.parse_expression(str(i))
          value_el = gast.Subscript(values, gast.Index(idx), ctx=gast.Load())
        apply_to_single_assignments(target_el, value_el, apply_fn)
    else:
      apply_fn(target, values)


def parallel_walk(node, other):
  """Walks two ASTs in parallel.

  The two trees must have identical structure.

  Corresponding nodes are yielded in pairs. String elements found in list
  fields are yielded as well. The traversal uses an explicit stack, so later
  siblings are visited before earlier ones.

  Args:
    node: Union[ast.AST, Iterable[ast.AST]]
    other: Union[ast.AST, Iterable[ast.AST]]
  Yields:
    Tuple[ast.AST, ast.AST]
  Raises:
    ValueError: if the two trees don't have identical structure.
  """
  if isinstance(node, (list, tuple)):
    node_stack = list(node)
  else:
    node_stack = [node]

  if isinstance(other, (list, tuple)):
    other_stack = list(other)
  else:
    other_stack = [other]

  # Validate the top-level inputs explicitly. An assert would be stripped
  # under -O, and a loop-only check would silently accept the case where
  # exactly one of the inputs is empty.
  if len(node_stack) != len(other_stack):
    raise ValueError(
        'inconsistent number of top-level nodes: {} and {}'.format(
            len(node_stack), len(other_stack)))

  # From here on both stacks grow and shrink in lockstep, so checking one of
  # them is sufficient.
  while node_stack:
    n = node_stack.pop()
    o = other_stack.pop()

    if (not isinstance(n, (ast.AST, gast.AST, str)) or
        not isinstance(o, (ast.AST, gast.AST, str)) or
        n.__class__.__name__ != o.__class__.__name__):
      raise ValueError('inconsistent nodes: {} ({}) and {} ({})'.format(
          n, n.__class__.__name__, o, o.__class__.__name__))

    yield n, o

    if isinstance(n, str):
      assert isinstance(o, str), 'The check above should have ensured this'
      continue

    for f in n._fields:
      n_child = getattr(n, f, None)
      o_child = getattr(o, f, None)
      # Fields that are private, missing or None on either side are skipped.
      if f.startswith('__') or n_child is None or o_child is None:
        continue

      if isinstance(n_child, (list, tuple)):
        if (not isinstance(o_child, (list, tuple)) or
            len(n_child) != len(o_child)):
          raise ValueError(
              'inconsistent values for field {}: {} and {}'.format(
                  f, n_child, o_child))
        node_stack.extend(n_child)
        other_stack.extend(o_child)

      elif isinstance(n_child, (gast.AST, ast.AST)):
        # A type mismatch with o_child is detected when the pair is popped.
        node_stack.append(n_child)
        other_stack.append(o_child)

      elif n_child != o_child:
        raise ValueError(
            'inconsistent values for field {}: {} and {}'.format(
                f, n_child, o_child))


class LambdaDefinitionMatcher(gast.NodeVisitor):
  """Finds lambda nodes that match a given lambda's signature."""

  def __init__(self, fn):
    self.fn = fn
    # Lambda nodes whose signature matches fn, in the order they were found.
    self.matching_nodes = []

  def _arg_name(self, node):
    """Returns the name of an argument given as a Name node, str or None."""
    if node is None:
      return None
    if isinstance(node, gast.Name):
      return node.id
    assert isinstance(node, str)
    return node

  def _argspec_matches(self, node):
    """Returns True if the arguments of node have the same names as fn's.

    Positional arguments, varargs, keyword varargs and keyword-only arguments
    are compared by name against the full argspec of fn.
    """
    arg_spec = tf_inspect.getfullargspec(self.fn)

    node_args = tuple(self._arg_name(arg) for arg in node.args.args)
    if node_args != tuple(arg_spec.args):
      return False

    if arg_spec.varargs != self._arg_name(node.args.vararg):
      return False

    if arg_spec.varkw != self._arg_name(node.args.kwarg):
      return False

    node_kwonlyargs = tuple(self._arg_name(arg) for arg in node.args.kwonlyargs)
    if node_kwonlyargs != tuple(arg_spec.kwonlyargs):
      return False

    return True

  def visit_Lambda(self, node):
    # Visit the children first, so that nested lambdas are considered too.
    self.generic_visit(node)

    if self.fn.__name__ != '<lambda>':
      return
    if not self._argspec_matches(node):
      return

    self.matching_nodes.append(node)


def find_matching_definitions(node, f):
  """Finds the lambda definitions in an AST that match a function's signature.

  Only lambda nodes are considered. If `f` is not a lambda (its __name__ is not
  '<lambda>'), the result is empty.

  Args:
    node: ast.AST, the tree to search
    f: Callable, the function whose definition is looked up
  Returns:
    Tuple[ast.AST, ...], the lambda nodes whose argument names match those of f
  """
  matcher = LambdaDefinitionMatcher(f)
  matcher.visit(node)
  return tuple(matcher.matching_nodes)