• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A node transformer that includes utilities for SCT."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21
22import gast
23
24from tensorflow.python.autograph.pyct import anno
25from tensorflow.python.autograph.pyct import compiler
26from tensorflow.python.autograph.pyct import pretty_printer
27from tensorflow.python.autograph.pyct import templates
28
29
30# TODO(znado): Use namedtuple.
class Context(object):
  """Mutable bookkeeping for a single source code transformation.

  Instances are updated in place while conversion runs and are not thread
  safe.

  Attributes:
    info: EntityInfo, immutable.
    current_origin: origin_info.OriginInfo of the last AST node that was
      processed successfully, or None when no origin has been recorded yet.
      Useful for error handling.
  """

  def __init__(self, info):
    # Nothing has been processed at construction time, so there is no origin.
    self.current_origin = None
    self.info = info
45
46
47# TODO(mdan): Use namedtuple.
class EntityInfo(object):
  """Describes a Python entity, for example a function or a class.

  Meant to be treated as immutable once constructed.

  Attributes:
    source_code: The entity's source code.
    source_file: The entity's source file.
    namespace: Dict[str, Any], containing symbols visible to the entity
      (excluding parameters).
    arg_values: dict[str->*], containing parameter values, if known.
    arg_types: dict[str->*], containing parameter types, if known.
  """

  # TODO(mdan): Remove the default and update tests.
  def __init__(self, source_code, source_file, namespace, arg_values,
               arg_types):
    self.source_code = source_code
    self.source_file = source_file
    self.namespace = namespace
    # None stands for "nothing known" and is normalized to an empty dict so
    # that readers never have to special-case it.
    self.arg_values = arg_values if arg_values is not None else {}
    self.arg_types = arg_types if arg_types is not None else {}
72
73
class _StateStack(object):
  """Typed stack abstraction.

  This class provides syntactic sugar for a stack of objects of known
  type. It allows accessing attributes of the object at the top of the stack
  directly against this object, which allows for very terse syntax.

  For example, this code:

    stack = _StateStack(Foo)
    stack.enter()
    stack.bar

  Is equivalent to:

    stack = []
    stack.append(Foo())
    foo = stack[-1]
    foo.bar

  Unless the type defines a `no_root` attribute, the constructor pushes one
  initial element, so attribute access works right after construction.

  See _State for more on how this is used.

  Attributes:
    type: Any, the type of objects that this stack holds
    level: int, the current stack depth
    value: Any, the instance of the object at the top of the stack
  """

  def __init__(self, type_):
    # Because we override __setattr__, we need to attach these attributes using
    # the superclass' setattr.
    object.__setattr__(self, 'type', type_)
    object.__setattr__(self, '_stack', [])
    if not hasattr(type_, 'no_root'):
      self.enter()

  def enter(self):
    """Pushes a new, default-constructed instance of `type` onto the stack."""
    self._stack.append(self.type())

  def exit(self):
    """Pops the object at the top of the stack and returns it."""
    return self._stack.pop()

  @property
  def level(self):
    return len(self._stack)

  @property
  def value(self):
    return self._stack[-1]

  def __iter__(self):
    return iter(self._stack)

  def __getattr__(self, key):
    # __getattr__ only runs after normal lookup has failed. If the name being
    # looked up is `_stack` itself, the instance was created without running
    # __init__ (copy.copy and unpickling do that through __new__); forwarding
    # would then re-enter this method forever, so fail the lookup instead.
    if key == '_stack':
      raise AttributeError(key)
    return getattr(self._stack[-1], key)

  def __setattr__(self, key, value):
    # All attribute writes are redirected to the object at the top of the
    # stack; the stack's own fields are set with object.__setattr__.
    setattr(self._stack[-1], key, value)
132
133
class _State(object):
  """Supporting class for nested scope variable space for converter.Base.

  Offers syntactic sugar over a dict of stacks of objects of known type.
  These structures are useful to keep state during AST walks. Several
  independent scopes can be tracked side by side. For example:

    s = _State()

    s[foo].enter()
    s[bar].enter()  # this will not affect s[foo]

  Element access has special semantics:
    * keys are a data type
    * element values are _StateStack(type=key) objects
    * missing elements are automatically added, similarly to defaultdict

  For example, the following block :

    s = _State()
    s[Foo]

  Is equivalent to:

    s = {}
    if Foo not in s:
      s[Foo] = _StateStack(Foo)
    s[Foo]

  See Base for how it's used.
  """

  def __init__(self):
    self._value = {}

  def __getitem__(self, key):
    # Stored values are always _StateStack instances, so None reliably means
    # that no stack has been created for this key yet.
    stack = self._value.get(key)
    if stack is None:
      stack = self._value[key] = _StateStack(key)
    return stack
173
174
175class Base(gast.NodeTransformer):
  """Base class for general-purpose code transformers.
177
178  This is an extension of ast.NodeTransformer that provides a few additional
179  functions, like state tracking within the scope of arbitrary node, helpers
180  for processing code blocks, debugging, mapping of transformed code to
181  original code, and others.
182
183  Scope-local state tracking: to keep state across nodes, at the level of
184  (possibly nested) scopes, use enter/exit_local_scope and set/get_local.
185  You must call enter/exit_local_scope manually, but the transformer detects
186  when they are not properly paired.
187
188  The transformer allows keeping state across calls to visit_* that is local to
189  arbitrary nodes and their descendants, using the self.state attribute.
190  Multiple independent scopes are allowed and automatically constructed.
191
192  For example, to keep track of the If node that encloses any Name node, one can
193  write:
194
195    class FooType(object):
196
197      def __init__(self):
198        self.foo_property = None
199
200    class DummyTransformer(Base):
201
202      def visit_If(self, node):
203        self.state[FooType].enter()
204        self.state[FooType].foo_property = node
205
206      def visit_Name(self, node):
207        self.state[FooType].foo_property  # will hold the innermost enclosing if
208  """
209
210  # TODO(mdan): Document all extra features.
211
212  def __init__(self, ctx):
213    """Initialize the transformer.
214
215    Subclasses should call this.
216
217    Args:
218      ctx: A Context object.
219    """
220    self._lineno = 0
221    self._col_offset = 0
222    self.ctx = ctx
223    self._enclosing_entities = []
224
225    # A stack that allows keeping mutable, scope-local state where scopes may be
226    # nested. For example, it can be used to track the usage of break
227    # statements in each loop, where loops may be nested.
228    self._local_scope_state = []
229    self.enter_local_scope()
230
231    # Allows scoping of local variables to keep state across calls to visit_*
232    # methods. Multiple scope hierchies may exist and are keyed by tag. A scope
233    # is valid at one or more nodes and all its children. Scopes created in
234    # child nodes supersede their parent. Scopes are isolated from one another.
235    self.state = _State()
236
237  @property
238  def enclosing_entities(self):
239    return tuple(self._enclosing_entities)
240
241  @property
242  def local_scope_level(self):
243    return len(self._local_scope_state)
244
245  def enter_local_scope(self, inherit=None):
246    """Deprecated.
247
248    Use self.state instead.
249
250    Marks entry into a new local scope.
251
252    Args:
253      inherit: Optional enumerable of variable names to copy from the parent
254        scope.
255    """
256    scope_entered = {}
257    if inherit:
258      this_scope = self._local_scope_state[-1]
259      for name in inherit:
260        if name in this_scope:
261          scope_entered[name] = this_scope[name]
262    self._local_scope_state.append(scope_entered)
263
264  def exit_local_scope(self, keep=None):
265    """Deprecated.
266
267    Use self.state instead.
268
269    Marks exit from the current local scope.
270
271    Args:
272      keep: Optional enumerable of variable names to copy into the parent scope.
273
274    Returns:
275      A dict containing the scope that has just been exited.
276    """
277    scope_left = self._local_scope_state.pop()
278    if keep:
279      this_scope = self._local_scope_state[-1]
280      for name in keep:
281        if name in scope_left:
282          this_scope[name] = scope_left[name]
283    return scope_left
284
285  def set_local(self, name, value):
286    """Deprecated. Use self.state instead."""
287    self._local_scope_state[-1][name] = value
288
289  def get_local(self, name, default=None):
290    """Deprecated. Use self.state instead."""
291    return self._local_scope_state[-1].get(name, default)
292
293  def debug_print(self, node):
294    """Helper method useful for debugging. Prints the AST."""
295    if __debug__:
296      print(pretty_printer.fmt(node))
297    return node
298
299  def debug_print_src(self, node):
300    """Helper method useful for debugging. Prints the AST as code."""
301    if __debug__:
302      print(compiler.ast_to_source(node))
303    return node
304
305  def create_assignment(self, target, expression):
306    template = """
307      target = expression
308    """
309    return templates.replace(template, target=target, expression=expression)
310
311  def visit_block(self, nodes, before_visit=None, after_visit=None):
312    """A more powerful version of generic_visit for statement blocks.
313
314    An example of a block is the body of an if statement.
315
316    This function allows specifying a postprocessing callback (the
317    after_visit argument) argument which can be used to move nodes to a new
318    destination. This is done by after_visit by returning a non-null
319    second return value, e.g. return new_node, new_destination.
320
321    For example, a transformer could perform the following move:
322
323        foo()
324        bar()
325        baz()
326
327        foo()
328        if cond:
329          bar()
330          baz()
331
332    The above could be done with a postprocessor of this kind:
333
334        def after_visit(node):
335          if node_is_function_call(bar):
336            new_container_node = build_cond()
337            new_container_node.body.append(node)
338            return new_container_node, new_container_node.body
339          else:
340            # Once we set a new destination, all subsequent items will be
341            # moved to it, so we don't need to explicitly handle baz.
342            return node, None
343
344    Args:
345      nodes: enumerable of AST node objects. If None, the function returns None.
346      before_visit: optional callable that is called before visiting each item
347        in nodes
348      after_visit: optional callable that takes in an AST node and returns a
349        tuple (new_node, new_destination). It is called after visiting each item
350        in nodes. Is used in the same was as the
351          visit_* methods: new_node will replace the node; if not None,
352            new_destination must be a list, and subsequent nodes will be placed
353            in this list instead of the list returned by visit_block.
354
355    Returns:
356      A list of AST node objects containing the transformed items fron nodes,
357      except those nodes that have been relocated using after_visit.
358    """
359    if nodes is None:
360      return None
361
362    results = []
363    node_destination = results
364    for node in nodes:
365      if before_visit:
366        # TODO(mdan): We can modify node here too, if ever needed.
367        before_visit()
368
369      replacement = self.visit(node)
370
371      if after_visit and replacement:
372        replacement, new_destination = after_visit(replacement)
373      else:
374        new_destination = None
375
376      if replacement:
377        if isinstance(replacement, (list, tuple)):
378          node_destination.extend(replacement)
379        else:
380          node_destination.append(replacement)
381
382      # Allow the postprocessor to reroute the remaining nodes to a new list.
383      if new_destination is not None:
384        node_destination = new_destination
385    return results
386
387  # TODO(mdan): Remove.
388  def apply_to_single_assignments(self, targets, values, apply_fn):
389    """Applies a function to each individual assignment.
390
391    This function can process a possibly-unpacked (e.g. a, b = c, d) assignment.
392    It tries to break down the unpacking if possible. In effect, it has the same
393    effect as passing the assigned values in SSA form to apply_fn.
394
395    Examples:
396
397    The following will result in apply_fn(a, c), apply_fn(b, d):
398
399        a, b = c, d
400
401    The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]):
402
403        a, b = c
404
405    The following will result in apply_fn(a, (b, c)):
406
407        a = b, c
408
409    It uses the visitor pattern to allow subclasses to process single
410    assignments individually.
411
412    Args:
413      targets: list, tuple of or individual AST node. Should be used with the
414        targets field of an ast.Assign node.
415      values: an AST node.
416      apply_fn: a function of a single argument, which will be called with the
417        respective nodes of each single assignment. The signature is
418        apply_fn(target, value), no return value.
419    """
420    if not isinstance(targets, (list, tuple)):
421      targets = (targets,)
422    for target in targets:
423      if isinstance(target, (gast.Tuple, gast.List)):
424        for i in range(len(target.elts)):
425          target_el = target.elts[i]
426          if isinstance(values, (gast.Tuple, gast.List)):
427            value_el = values.elts[i]
428          else:
429            value_el = gast.Subscript(values, gast.Index(i), ctx=gast.Store())
430          self.apply_to_single_assignments(target_el, value_el, apply_fn)
431      else:
432        # TODO(mdan): Look into allowing to rewrite the AST here.
433        apply_fn(target, values)
434
435  def _get_source(self, node):
436    try:
437      source, _ = compiler.ast_to_source(node)
438      return source
439    # pylint: disable=broad-except
440    # This function is used for error reporting.  If an exception occurs here,
441    # it should be suppressed, in favor of emitting as informative a message
442    # about the original error as possible.
443    except Exception:
444      return '<could not convert AST to source>'
445
446  def visit(self, node):
447    if not isinstance(node, gast.AST):
448      # This is not that uncommon a mistake: various node bodies are lists, for
449      # example, posing a land mine for transformers that need to recursively
450      # call `visit`.  The error needs to be raised before the exception handler
451      # below is installed, because said handler will mess up if `node` is not,
452      # in fact, a node.
453      msg = ('invalid value for "node": expected "ast.AST", got "{}"; to'
454             ' visit lists of nodes, use "visit_block" instead').format(
455                 type(node))
456      raise ValueError(msg)
457
458    did_enter_function = False
459    local_scope_size_at_entry = len(self._local_scope_state)
460    processing_expr_node = False
461
462    parent_origin = self.ctx.current_origin
463    if isinstance(node, (gast.FunctionDef, gast.ClassDef, gast.Lambda)):
464      did_enter_function = True
465    elif isinstance(node, gast.Expr):
466      processing_expr_node = True
467
468    if did_enter_function:
469      self._enclosing_entities.append(node)
470
471    if anno.hasanno(node, anno.Basic.ORIGIN):
472      self.ctx.current_origin = anno.getanno(node, anno.Basic.ORIGIN)
473
474    if processing_expr_node:
475      entry_expr_value = node.value
476
477    if not anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
478      result = super(Base, self).visit(node)
479    self.ctx.current_origin = parent_origin
480
481    # Adjust for consistency: replacing the value of an Expr with
482    # an Assign node removes the need for the Expr node.
483    if processing_expr_node:
484      if isinstance(result, gast.Expr) and result.value != entry_expr_value:
485        # When the replacement is a list, it is assumed that the list came
486        # from a template that contained a number of statements, which
487        # themselves are standalone and don't require an enclosing Expr.
488        if isinstance(result.value,
489                      (list, tuple, gast.Assign, gast.AugAssign)):
490          result = result.value
491
492    # On exception, the local scope integrity is not guaranteed.
493    if did_enter_function:
494      self._enclosing_entities.pop()
495
496    if local_scope_size_at_entry != len(self._local_scope_state):
497      raise AssertionError(
498          'Inconsistent local scope stack. Before entering node %s, the'
499          ' stack had length %d, after exit it has length %d. This'
500          ' indicates enter_local_scope and exit_local_scope are not'
501          ' well paired.' % (node, local_scope_size_at_entry,
502                             len(self._local_scope_state)))
503    return result
504