• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""AST manipulation utilities."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import ast
22
23import gast
24
25from tensorflow.python.autograph.pyct import anno
26from tensorflow.python.autograph.pyct import parser
27from tensorflow.python.autograph.pyct import qual_names
28
29
class CleanCopier(object):
  """Deep-copies ASTs, leaving out fields whose names start with '__'.

  Only the annotation keys listed in `preserve_annos` are carried over from
  each original node to its copy.
  """

  def __init__(self, preserve_annos):
    super(CleanCopier, self).__init__()
    # Iterable of annotation keys to copy over; None or empty copies none.
    self.preserve_annos = preserve_annos

  def copy(self, node):
    """Returns a deep copy of `node` (see copy_clean for what is excluded).

    Lists and tuples are copied element-wise into a container of the same
    type. Values that are not AST nodes are returned as-is.
    """
    if isinstance(node, list):
      return list(map(self.copy, node))
    if isinstance(node, tuple):
      return tuple(map(self.copy, node))
    if not isinstance(node, (gast.AST, ast.AST)):
      # Anything that is not an AST node, list or tuple is treated as a value
      # type and shared between the original and the copy.
      return node

    copied_fields = {
        name: self.copy(getattr(node, name))
        for name in node._fields
        if not name.startswith('__') and hasattr(node, name)
    }
    duplicate = type(node)(**copied_fields)

    for key in self.preserve_annos or ():
      anno.copyanno(node, duplicate, key)
    return duplicate
61
62
def copy_clean(node, preserve_annos=None):
  """Returns a deep copy of an AST.

  Fields whose names are prefixed by '__' are left out of the copy. The only
  exception is annotations whose keys are listed in `preserve_annos`.

  Args:
    node: ast.AST, or a list/tuple of AST nodes
    preserve_annos: Optional[Set[Hashable]], annotation keys to include in the
        copy
  Returns:
    ast.AST, or a list/tuple of copies mirroring `node`
  """
  copier = CleanCopier(preserve_annos)
  return copier.copy(node)
77
78
class SymbolRenamer(gast.NodeTransformer):
  """Transformer that renames symbols to simple names.

  `name_map` is looked up with qual_names.QN keys. Its values are converted
  with str() to produce the new names.
  """

  def __init__(self, name_map):
    self.name_map = name_map

  def _process_name_node(self, node):
    """Replaces `node` by a gast.Name if its QN annotation is in the map."""
    qn = anno.getanno(node, anno.Basic.QN)
    if qn not in self.name_map:
      return self.generic_visit(node)

    replacement = gast.Name(
        str(self.name_map[qn]),
        ctx=node.ctx,
        annotation=None,
        type_comment=None)
    # All annotations get carried over.
    for key in anno.keys(node):
      anno.copyanno(node, replacement, key)
    return replacement

  def _process_list_of_strings(self, names):
    """Renames the mapped entries of `names` in place and returns the list."""
    for idx, name in enumerate(names):
      qn = qual_names.QN(name)
      if qn in self.name_map:
        names[idx] = str(self.name_map[qn])
    return names

  def _rename_declared_names(self, node):
    """Renames the identifiers listed by a Global/Nonlocal statement."""
    node.names = self._process_list_of_strings(node.names)
    return node

  def visit_Nonlocal(self, node):
    return self._rename_declared_names(node)

  def visit_Global(self, node):
    return self._rename_declared_names(node)

  def visit_Name(self, node):
    return self._process_name_node(node)

  def visit_Attribute(self, node):
    # Only attributes annotated with a QN can be renamed; renaming the
    # attribute part alone is not supported.
    if not anno.hasanno(node, anno.Basic.QN):
      return self.generic_visit(node)
    return self._process_name_node(node)

  def visit_FunctionDef(self, node):
    fn_qn = qual_names.QN(node.name)
    if fn_qn in self.name_map:
      node.name = str(self.name_map[fn_qn])
    return self.generic_visit(node)
128
129
def rename_symbols(node, name_map):
  """Renames symbols in an AST. Requires qual_names annotations.

  Args:
    node: ast.AST, or a list/tuple of AST nodes
    name_map: mapping whose keys are qual_names.QN; values are converted with
        str() to form the new names
  Returns:
    The transformed node(s), in a container of the same type as `node` when
    `node` is a list or tuple.
  """
  renamer = SymbolRenamer(name_map)
  if isinstance(node, list):
    return list(map(renamer.visit, node))
  if isinstance(node, tuple):
    return tuple(map(renamer.visit, node))
  return renamer.visit(node)
138
139
def keywords_to_dict(keywords):
  """Converts a list of ast.keyword objects to a gast.Dict node.

  A `name=value` keyword becomes a `'name': value` entry. A `**value` keyword
  (whose `arg` is None) becomes a `**value` unpacking entry, which a Dict node
  represents with a None key.

  Args:
    keywords: Iterable[ast.keyword], e.g. the `keywords` field of a Call node
  Returns:
    gast.Dict
  """
  keys = []
  values = []
  for kw in keywords:
    if kw.arg is None:
      # `**kwargs`: a None key encodes dictionary unpacking. A Constant(None)
      # key would instead produce the literal entry {None: kwargs}.
      keys.append(None)
    else:
      keys.append(gast.Constant(kw.arg, kind=None))
    values.append(kw.value)
  return gast.Dict(keys=keys, values=values)
148
149
class PatternMatcher(gast.NodeVisitor):
  """Matches a node against a pattern represented by a node.

  The node and the pattern are compared field by field. In the pattern, a
  Name node with id '_', a '_' string, or a single-element list holding either
  of them is a wildcard for the corresponding field value, list or list item.
  After visiting, `matches` holds the result.
  """

  def __init__(self, pattern):
    self.pattern = pattern
    # Patterns of the enclosing nodes, restored as the visit unwinds.
    self.pattern_stack = []
    self.matches = True

  def compare_and_visit(self, node, pattern):
    """Visits `node` with `pattern` as the current pattern."""
    self.pattern_stack.append(self.pattern)
    self.pattern = pattern
    self.generic_visit(node)
    self.pattern = self.pattern_stack.pop()

  def no_match(self):
    self.matches = False
    return False

  def is_wildcard(self, p):
    """Returns True if the pattern value `p` is a wildcard."""
    if isinstance(p, (list, tuple)) and len(p) == 1:
      p, = p
    if isinstance(p, gast.Name) and p.id == '_':
      return True
    if p == '_':
      return True
    return False

  def _compare_values(self, v, p):
    """Compares a field value or list item `v` against its pattern `p`.

    Lists are compared item by item using these same rules. List items
    therefore get the same wildcard and node-type checks as scalar fields, and
    non-AST items (e.g. the `str` names of Global/Nonlocal) are compared as
    values rather than visited as nodes.
    """
    if self.is_wildcard(p):
      return
    if isinstance(v, (list, tuple)):
      if not isinstance(p, (list, tuple)) or len(v) != len(p):
        self.no_match()
        return
      for v_item, p_item in zip(v, p):
        if not self.matches:
          return
        self._compare_values(v_item, p_item)
    elif isinstance(v, (gast.AST, ast.AST)):
      # Node types must agree; field-wise comparison alone cannot tell apart
      # node types with identical fields (e.g. Break/Continue, List/Tuple).
      if not isinstance(v, type(p)) and not isinstance(p, type(v)):
        self.no_match()
        return
      self.compare_and_visit(v, p)
    elif v != p:
      # Assume everything else is a value type.
      self.no_match()

  def generic_visit(self, node):
    if not self.matches:
      return

    pattern = self.pattern
    for f in node._fields:
      if f.startswith('__'):
        continue

      if not hasattr(node, f):
        # A field the node lacks only matches a missing or empty pattern field.
        if hasattr(pattern, f) and getattr(pattern, f):
          return self.no_match()
        continue
      if not hasattr(pattern, f):
        return self.no_match()

      self._compare_values(getattr(node, f), getattr(pattern, f))
      if not self.matches:
        return False
212
213
def matches(node, pattern):
  """Basic pattern matcher for AST.

  The pattern may contain wildcards represented by the symbol '_'. The node
  and the pattern are compared field by field, recursively; wherever the
  pattern holds a Name node with id='_', the corresponding part of the node
  is not compared further.

  Args:
    node: ast.AST
    pattern: ast.AST, or a source string, which is parsed with
        parser.parse_str
  Returns:
    bool
  """
  pattern_node = (
      parser.parse_str(pattern) if isinstance(pattern, str) else pattern)

  visitor = PatternMatcher(pattern_node)
  visitor.visit(node)
  return visitor.matches
233
234
235# TODO(mdan): Once we have error tracing, we may be able to just go to SSA.
def apply_to_single_assignments(targets, values, apply_fn):
  """Applies a function to each individual assignment.

  Breaks a possibly-unpacking assignment (e.g. a, b = c, d) down into single
  assignments where possible. In effect, apply_fn sees the assigned values
  much as it would in SSA form.

  Examples:

  The following will result in apply_fn(a, c), apply_fn(b, d):

      a, b = c, d

  The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]):

      a, b = c

  The following will result in apply_fn(a, (b, c)):

      a = b, c

  Nested Tuple/List targets are unpacked recursively. When `values` is itself
  a Tuple/List, its elements are paired with the target's by position. The
  code assumes it has at least as many elements as the target. Starred
  targets (e.g. a, *b = c) are not special-cased: they are paired with a
  plain index like any other element.

  Args:
    targets: Union[List[ast.AST], Tuple[ast.AST, ...], ast.AST], should be
        used with the targets field of an ast.Assign node
    values: ast.AST
    apply_fn: Callable[[ast.AST, ast.AST], None], called with the
        respective nodes of each single assignment
  """
  if not isinstance(targets, (list, tuple)):
    targets = (targets,)

  for target in targets:
    if not isinstance(target, (gast.Tuple, gast.List)):
      apply_fn(target, values)
      continue

    values_are_literal = isinstance(values, (gast.Tuple, gast.List))
    for position, target_el in enumerate(target.elts):
      if values_are_literal:
        value_el = values.elts[position]
      else:
        # Index into the (unknown) value: values[position].
        index_node = parser.parse_expression(str(position))
        value_el = gast.Subscript(values, index_node, ctx=gast.Load())
      apply_to_single_assignments(target_el, value_el, apply_fn)
281
282
def parallel_walk(node, other):
  """Walks two ASTs in parallel.

  The two trees must have identical structure. Paired nodes must have the same
  class name, so e.g. an `ast` node may be paired with its `gast` counterpart.
  Paired list fields must have the same length, and paired non-AST field
  values must compare equal. A child field that is None in either tree is
  skipped.

  Args:
    node: Union[ast.AST, List[ast.AST], Tuple[ast.AST, ...]]
    other: Union[ast.AST, List[ast.AST], Tuple[ast.AST, ...]]
  Yields:
    Tuple[ast.AST, ast.AST], each pair of corresponding nodes (string and None
    list items are yielded as pairs too)
  Raises:
    ValueError: if the two trees don't have identical structure.
  """
  node_stack = list(node) if isinstance(node, (list, tuple)) else [node]
  other_stack = list(other) if isinstance(other, (list, tuple)) else [other]

  # Checked explicitly so that mismatched top-level sequences raise the
  # documented ValueError. Otherwise they would trip the assert below, or,
  # under `python -O`, be silently walked only in part.
  if len(node_stack) != len(other_stack):
    raise ValueError('inconsistent number of nodes: {} and {}'.format(
        len(node_stack), len(other_stack)))

  while node_stack and other_stack:
    # Children are only pushed in equal-length batches, so the stacks stay in
    # lockstep.
    assert len(node_stack) == len(other_stack)
    n = node_stack.pop()
    o = other_stack.pop()

    if ((not isinstance(n, (ast.AST, gast.AST, str)) and n is not None) or
        (not isinstance(o, (ast.AST, gast.AST, str)) and o is not None) or
        n.__class__.__name__ != o.__class__.__name__):
      raise ValueError('inconsistent nodes: {} ({}) and {} ({})'.format(
          n, n.__class__.__name__, o, o.__class__.__name__))

    yield n, o

    # Strings and None have no children to descend into.
    if isinstance(n, str):
      assert isinstance(o, str), 'The check above should have ensured this'
      continue
    if n is None:
      assert o is None, 'The check above should have ensured this'
      continue

    for f in n._fields:
      n_child = getattr(n, f, None)
      o_child = getattr(o, f, None)
      # NOTE: a child that is None on only one side is skipped rather than
      # reported as a structural mismatch.
      if f.startswith('__') or n_child is None or o_child is None:
        continue

      if isinstance(n_child, (list, tuple)):
        if (not isinstance(o_child, (list, tuple)) or
            len(n_child) != len(o_child)):
          raise ValueError(
              'inconsistent values for field {}: {} and {}'.format(
                  f, n_child, o_child))
        node_stack.extend(n_child)
        other_stack.extend(o_child)

      elif isinstance(n_child, (gast.AST, ast.AST)):
        node_stack.append(n_child)
        other_stack.append(o_child)

      elif n_child != o_child:
        raise ValueError(
            'inconsistent values for field {}: {} and {}'.format(
                f, n_child, o_child))