• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""AST conversion templates.
16
17Adapted from Tangent.
18"""
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import ast
25import textwrap
26
27import gast
28
29from tensorflow.python.autograph.pyct import anno
30from tensorflow.python.autograph.pyct import ast_util
31from tensorflow.python.autograph.pyct import parser
32from tensorflow.python.autograph.pyct import qual_names
33
34
class ContextAdjuster(gast.NodeTransformer):
  """Adjusts the ctx field of nodes to ensure consistency.

  This transformer can change the ctx fields of a variable, tuple and other
  AST elements that allow one, based on whether the element is being read or
  written.

  The current override is a ctx node class (e.g. `gast.Load`) that is
  instantiated afresh for every node it is applied to, or None to leave ctx
  untouched. A visitor may change the override for the children of the node
  it handles; `visit` restores the previous override afterwards, so such
  changes only affect that node's subtree.
  """

  def __init__(self, override_value):
    # A ctx node class (called to build each new ctx), or None.
    self._ctx_override = override_value

  def visit(self, node):
    original_override = self._ctx_override
    node = super(ContextAdjuster, self).visit(node)
    # Any node with a ctx field must end up with it set, either through an
    # override or because it was already set by whoever built the node.
    if hasattr(node, 'ctx'):
      assert node.ctx is not None, 'node {} has ctx unset'.format(node)
    self._ctx_override = original_override
    return node

  def _apply_override(self, node):
    if self._ctx_override is not None:
      node.ctx = self._ctx_override()

  def visit_Attribute(self, node):
    self._apply_override(node)
    # The object whose attribute is accessed is only ever read.
    self._ctx_override = gast.Load
    node = self.generic_visit(node)
    return node

  def visit_Tuple(self, node):
    self._apply_override(node)
    return self.generic_visit(node)

  def visit_List(self, node):
    self._apply_override(node)
    return self.generic_visit(node)

  def visit_Name(self, node):
    self._apply_override(node)
    return self.generic_visit(node)

  def visit_Starred(self, node):
    # A starred element shares the context of the Tuple/List it belongs to:
    # in the target `a, *b = c` the Starred node must be Store, like `b`.
    # A Load Starred inside a Store target is rejected by Python's AST
    # validation when the tree is compiled.
    self._apply_override(node)
    return self.generic_visit(node)

  def visit_Call(self, node):
    self._apply_override(node)
    # We may be able to override these to Load(), but for now it's simpler
    # to just assert that they're set.
    self._ctx_override = None
    return self.generic_visit(node)

  def visit_Dict(self, node):
    # We may be able to override these to Load(), but for now it's simpler
    # to just assert that they're set.
    self._ctx_override = None
    return self.generic_visit(node)

  def visit_Subscript(self, node):
    self._apply_override(node)
    # Both the subscripted value and the slice are read, never written.
    self._ctx_override = gast.Load
    # generic_visit visits node.value as well as the slice, so there is no
    # need to visit node.value separately beforehand.
    return self.generic_visit(node)

  def visit_comprehension(self, node):
    # We may be able to override some of these, but for now it's simpler
    # to just assert that they're set.
    self._ctx_override = None
    return self.generic_visit(node)

  def visit_Lambda(self, node):
    # We may be able to override some of these, but for now it's simpler
    # to just assert that they're set.
    self._ctx_override = None
    return self.generic_visit(node)
106
107
class ReplaceTransformer(gast.NodeTransformer):
  """Replace AST nodes.

  Placeholders are matched by name against the keys of `replacements`:

    * Name nodes are replaced by copies of the replacement node(s), with
      their ctx adjusted to that of the placeholder (see visit_Name).
    * Keyword arguments whose name is a placeholder are replaced by copies
      of the replacement keyword node(s).
    * Function and attribute names that are placeholders are renamed to the
      `id` of the replacement Name node.
  """

  def __init__(self, replacements):
    """Create a new ReplaceTransformer.

    Args:
      replacements: A mapping from placeholder names to (lists of) AST nodes
          that these placeholders will be replaced by.
    """
    self.replacements = replacements
    # Not read by any method of this class.
    self.in_replacements = False
    # Annotation keys passed to ast_util.copy_clean as `preserve_annos` when
    # replacement nodes are copied (see _prepare_replacement).
    self.preserved_annos = {
        anno.Basic.DIRECTIVES,
        anno.Basic.EXTRA_LOOP_TEST,
        anno.Basic.ORIGIN,
        anno.Basic.SKIP_PROCESSING,
        anno.Static.ORIG_DEFINITIONS,
        'function_context_name',
    }

  def _prepare_replacement(self, replaced, key):
    """Prepares a replacement AST that's safe to swap in for a node.

    The replacement is copied with ast_util.copy_clean (passing
    `preserved_annos`), presumably so the same replacement can be inserted at
    several places without the copies sharing node objects.

    Args:
      replaced: ast.AST, the node being replaced. Currently unused.
      key: Hashable, the key of the replacement AST
    Returns:
      The copied replacement node(s) as a sequence; a single replacement node
      is wrapped in a one-element list.
    """
    repl = self.replacements[key]

    new_nodes = ast_util.copy_clean(repl, preserve_annos=self.preserved_annos)
    if isinstance(new_nodes, gast.AST):
      new_nodes = [new_nodes]

    return new_nodes

  def visit_Expr(self, node):
    # When replacing a placeholder with an entire statement, the replacement
    # must stand on its own and not be wrapped in an Expr.
    # NOTE: the wrapper is also dropped when the value is replaced by a plain
    # expression, so a statement consisting of a single placeholder yields the
    # bare replacement expression rather than an Expr statement.
    new_value = self.visit(node.value)
    if new_value is node.value:
      return node
    return new_value

  def visit_keyword(self, node):
    if node.arg not in self.replacements:
      return self.generic_visit(node)

    repl = self._prepare_replacement(node, node.arg)
    # _prepare_replacement never returns a bare node: a single replacement
    # keyword arrives here as a one-element list, which is valid wherever a
    # keyword can appear (keyword fields are lists).
    if (isinstance(repl, (list, tuple)) and repl and
        all(isinstance(r, gast.keyword) for r in repl)):
      return repl
    # TODO(mdan): We may allow replacing with a string as well.
    # For example, if one wanted to replace foo with bar in foo=baz, then
    # we could allow changing just node arg, so that we end up with bar=baz.
    raise ValueError(
        'a keyword argument may only be replaced by another keyword or a '
        'non-empty list of keywords. Found: {} for keyword {}'.format(
            repl, node.arg))

  def visit_FunctionDef(self, node):
    # Replace placeholders inside the function first, then rename it.
    node = self.generic_visit(node)
    if node.name not in self.replacements:
      return node

    repl = self.replacements[node.name]
    if not isinstance(repl, (gast.Name, ast.Name)):
      raise ValueError(
          'a function name can only be replaced by a Name node. Found: %s' %
          repl)
    node.name = repl.id
    return node

  def visit_Attribute(self, node):
    node = self.generic_visit(node)
    if node.attr not in self.replacements:
      return node

    repl = self.replacements[node.attr]
    # Accept the same Name node types as visit_FunctionDef.
    if not isinstance(repl, (gast.Name, ast.Name)):
      raise ValueError(
          'An attribute can only be replaced by a Name node. Found: %s' % repl)
    node.attr = repl.id
    return node

  def visit_Name(self, node):
    """Replaces a placeholder Name with copies of its replacement node(s).

    Returns:
      The node itself if it is not a placeholder; otherwise the single
      replacement node, or a list when the replacement has zero or several
      nodes.
    """
    if node.id not in self.replacements:
      return node

    new_nodes = self._prepare_replacement(node, node.id)

    if not new_nodes:
      return new_nodes

    # Preserve the target context: each top-level replacement node that has
    # a ctx receives the placeholder's ctx type (e.g. Store for an assignment
    # target), and ContextAdjuster propagates it to children per its rules.
    adjuster = ContextAdjuster(type(node.ctx))
    for n in new_nodes:
      if hasattr(n, 'ctx'):
        adjuster.visit(n)

    if len(new_nodes) == 1:
      new_nodes, = new_nodes

    return new_nodes
216
217
def _convert_to_ast(n):
  """Converts a replacement value of a known type to AST.

  Strings become Name nodes whose id is the string, QN objects are converted
  through their `ast()` method, and lists/tuples are converted element-wise
  (a list stays a list, a tuple stays a tuple). Any other value, AST nodes
  included, is returned unchanged.

  A Name built from a string has its ctx left as None: the context cannot be
  known in isolation and is filled in according to the template the node ends
  up in (see ReplaceTransformer.visit_Name).
  """
  if isinstance(n, str):
    return gast.Name(id=n, ctx=None, annotation=None, type_comment=None)
  elif isinstance(n, qual_names.QN):
    return n.ast()
  elif isinstance(n, (list, tuple)):
    converted = [_convert_to_ast(elem) for elem in n]
    return converted if isinstance(n, list) else tuple(converted)
  else:
    return n
232
233
def replace(template, **replacements):
  """Replaces placeholders in a Python template.

  AST Name and Tuple nodes always receive the context inferred from the
  template. However, when replacing more complex nodes (that can potentially
  contain Name children), the caller is responsible for setting the
  appropriate context.

  Args:
    template: A string representing Python code. Any symbol name that appears
        in the template code can be used as a placeholder.
    **replacements: A mapping from placeholder names to (lists of) AST nodes
        that these placeholders will be replaced by. String values are also
        supported as a shorthand for AST Name nodes with the respective ID.

  Returns:
    A list of AST nodes: the top-level nodes of the parsed template, with the
    replacements made and each passed through qual_names.resolve. When a
    top-level node is replaced by a list of nodes (e.g. a placeholder
    statement replaced by several statements), those nodes are spliced into
    the result in order.

  Raises:
    ValueError: if the arguments are incorrect.
  """
  if not isinstance(template, str):
    raise ValueError('Expected string template, got %s' % type(template))
  converted = {k: _convert_to_ast(v) for k, v in replacements.items()}
  template_str = parser.STANDARD_PREAMBLE + textwrap.dedent(template)
  nodes = parser.parse(
      template_str,
      preamble_len=parser.STANDARD_PREAMBLE_LEN,
      single_node=False)
  # The transformer only reads its replacements, so a single instance can
  # process every top-level node.
  transformer = ReplaceTransformer(converted)
  results = []
  for node in nodes:
    node = transformer.visit(node)
    if isinstance(node, (list, tuple)):
      results.extend(node)
    else:
      results.append(node)
  return [qual_names.resolve(r) for r in results]
277
278
def replace_as_expression(template, **replacements):
  """Variant of replace that generates expressions, instead of code blocks.

  Args:
    template: A string containing a single Python expression; see `replace`.
    **replacements: Same as for `replace`.

  Returns:
    The expression node generated by the template.

  Raises:
    ValueError: if the template does not generate exactly one node, or that
        node is neither an expression statement nor an expression.
  """
  replacement = replace(template, **replacements)
  if len(replacement) != 1:
    raise ValueError(
        'single expression expected; for more general templates use replace')
  node, = replacement

  if isinstance(node, gast.Expr):
    return node.value
  elif isinstance(node, gast.expr):
    # `replace` drops the Expr wrapper when the whole statement was a single
    # placeholder replaced by an expression, so any bare expression (not only
    # a Name) can show up here.
    return node

  raise ValueError(
      'the template is expected to generate an expression or a name node;'
      ' instead found %s' % node)
295