• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A node transformer that includes utilities for SCT."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import enum
23
24import gast
25
26from tensorflow.python.autograph.pyct import anno
27from tensorflow.python.autograph.pyct import parser
28from tensorflow.python.autograph.pyct import pretty_printer
29from tensorflow.python.autograph.pyct import templates
30
31
class AnalysisLevel(enum.IntEnum):
  """Enumerates the available analysis levels as integers.

  Being an IntEnum, members compare and order like their integer values.
  """

  NONE = 0
  ACTIVITY = 1
  DEFINEDNESS = 2
  LIVENESS = 3
38
39
40# TODO(znado): Use namedtuple.
class Context(object):
  """Holds the state associated with one source code transformation.

  Instances are mutable and get updated while conversion is in progress. Not
  thread safe.

  Attributes:
    info: EntityInfo, immutable.
    namer: naming.Namer.
    current_origin: origin_info.OriginInfo, the OriginInfo of the last AST node
      that was processed successfully. Starts out as None. Useful for error
      handling.
    user: A user-supplied context object. It is opaque to the infrastructure,
      but will be passed through to all custom transformations.
  """

  def __init__(self, info, namer, user_context):
    self.info = info
    self.namer = namer
    # No node has been processed yet, so there is no origin to report.
    self.current_origin = None
    self.user = user_context
60
61
62# TODO(mdan): Move to a standalone file.
class EntityInfo(
    collections.namedtuple('EntityInfo', [
        'name',
        'source_code',
        'source_file',
        'future_features',
        'namespace',
    ])):
  """Describes a Python entity, such as a function or a class.

  Immutable.

  Attributes:
    name: The name that identifies this entity.
    source_code: The entity's source code.
    source_file: The entity's source file.
    future_features: Tuple[Text], the future features that this entity was
      compiled with. See
      https://docs.python.org/2/reference/simple_stmts.html#future.
    namespace: Dict[str, Any], containing symbols visible to the entity
      (excluding parameters).
  """
85
86
87class _StateStack(object):
88  """Templated context manager.
89
90  This class provides syntactic sugar for a stack of objects of known
91  type. It allows accessing attributes of the object at the top of the stack
92  directly against this object, which allows for very terse syntax.
93
94  For example, this code:
95
96    stack = _StateStack(Foo)
97    stack.enter()
98    stack.bar
99
100  Is equivalent to:
101
102    stack = []
103    stack.append(Foo())
104    foo = stack[-1]
105    foo.bar
106
107  See _State for more on how this is used.
108
109  Attributes:
110    type: Any, the type of objects that this stack holds
111    level: int, the current stack depth
112    stack: List[Any], the actual stack
113    value: Any, the instance of the object at the top of the stack
114  """
115
116  def __init__(self, type_):
117    # Because we override __setattr__, we need to attach these attributes using
118    # the superclass' setattr.
119    object.__setattr__(self, 'type', type_)
120    object.__setattr__(self, '_stack', [])
121    if not hasattr(type_, 'no_root'):
122      self.enter()
123
124  def __enter__(self):
125    self.enter()
126    return self
127
128  def __exit__(self, exc_type, exc_value, traceback):
129    self.exit()
130
131  def enter(self):
132    self._stack.append(self.type())
133
134  def exit(self):
135    self._stack.pop()
136
137  @property
138  def stack(self):
139    return self._stack
140
141  @property
142  def level(self):
143    return len(self._stack)
144
145  @property
146  def value(self):
147    return self._stack[-1]
148
149  def __iter__(self):
150    return iter(self._stack)
151
152  def __getattr__(self, key):
153    return getattr(self._stack[-1], key)
154
155  def __setattr__(self, key, value):
156    setattr(self._stack[-1], key, value)
157
158
159class _State(object):
160  """Syntactic sugar for accessing an instance of a StateStack context manager.
161
162  This structure offers syntactic sugar over a dict of stacks of objects
163  of known type. These structures are useful to keep state during AST walks.
164  Multiple different scopes can be tracked in parallel. For example:
165
166    s = _State()
167
168    s[foo].enter()
169    s[bar].enter()  # this will not affect s[foo]
170
171  Element access has special semantics:
172    * keys are a data type
173    * element values are _StateStack(type=key) objects
174    * missing elements are automatically added, similarly to defaultdict
175
176  For example, the following block :
177
178    _State s
179    s[Foo]
180
181  Is equivalent to:
182
183    s = {}
184    if Foo not in s:
185      s[Foo] = Foo()
186    s[Foo]
187
188  See Base for how it's used.
189  """
190
191  def __init__(self):
192    self._value = {}
193
194  def __getitem__(self, key):
195    if key not in self._value:
196      self._value[key] = _StateStack(key)
197    return self._value[key]
198
199
class NodeStateTracker(object):
  """Mixin with shared helpers for classes that walk a Python AST.

  It is meant to be combined with a visitor class that provides `visit` (see
  the example below). It contributes state tracking scoped to arbitrary nodes,
  a helper for processing statement blocks, and debugging helpers.

  State that is local to a node and its descendants is kept in the
  `self.state` attribute. Multiple independent scopes are allowed; each one is
  keyed by a type and is constructed automatically on first access.

  For example, to keep track of the `If` node that encloses any `Name` node,
  one can write:

  ```
    class FooType(object):

      def __init__(self):
        self.foo_property = None

    class DummyTransformer(NodeStateTracker, ast.NodeTransformer):

      def visit_If(self, node):
        self.state[FooType].enter()
        self.state[FooType].foo_property = node
        node = self.generic_visit(node)
        self.state[FooType].exit()
        return node

      def visit_Name(self, node):
        self.state[FooType].foo_property  # will hold the innermost enclosing if
  ```

  Alternatively, the `enter()`/`exit()` calls can be managed by a `with`
  statement:

  ```
      def visit_If(self, node):
        with self.state[FooType] as foo:
          foo.foo_property = node
          return self.generic_visit(node)
  ```
  """

  # TODO(mdan): Document all extra features.

  def __init__(self, ctx):
    """Initialize the transformer.

    Subclasses should call this.

    Args:
      ctx: A Context object.
    """
    self._lineno = 0
    self._col_offset = 0
    self.ctx = ctx

    # Scoped storage for values that must survive across calls to visit_*
    # methods. Several scope hierarchies may coexist, each keyed by a tag. A
    # scope is valid at one or more nodes and all their children. Scopes
    # created in child nodes supersede their parent. Scopes are isolated from
    # one another.
    self.state = _State()

  def debug_print(self, node):
    """Debugging aid: prints the AST of `node` and returns `node` unchanged."""
    if __debug__:
      print(pretty_printer.fmt(node))
    return node

  def debug_print_src(self, node):
    """Debugging aid: prints `node` as source code and returns it unchanged."""
    if __debug__:
      print(parser.unparse(node))
    return node

  def visit_block(self, nodes, before_visit=None, after_visit=None):
    """A more powerful version of generic_visit for statement blocks.

    An example of a block is the body of an if statement.

    Each item of `nodes` is passed to `self.visit`. The optional `after_visit`
    callback can post-process the result and, in addition, redirect all the
    items that follow into a different list. It does so by returning a
    non-None second value, e.g. `return new_node, new_destination`.

    For example, a transformer could turn this:

        foo()
        bar()
        baz()

    into this:

        foo()
        if cond:
          bar()
          baz()

    using a postprocessor of this kind:

        def after_visit(node):
          if node_is_function_call(bar):
            new_container_node = build_cond()
            new_container_node.body.append(node)
            return new_container_node, new_container_node.body
          else:
            # Once we set a new destination, all subsequent items will be
            # moved to it, so we don't need to explicitly handle baz.
            return node, None

    Args:
      nodes: enumerable of AST node objects. If None, the function returns None.
      before_visit: optional callable, invoked with no arguments before
        visiting each item in nodes.
      after_visit: optional callable that takes in an AST node and returns a
        tuple (new_node, new_destination). It is called after visiting each
        item in nodes, but only when the visit produced a truthy result.
        new_node is used in the same way as the result of the visit_* methods:
        it replaces the node. If not None, new_destination must be a list, and
        subsequent nodes will be placed in that list instead of the list
        returned by visit_block.

    Returns:
      A list of AST node objects containing the transformed items from nodes,
      except those nodes that have been relocated using after_visit.
    """
    if nodes is None:
      return None

    results = []
    # The list currently receiving output; after_visit may swap it out.
    destination = results
    for node in nodes:
      if before_visit:
        # TODO(mdan): We can modify node here too, if ever needed.
        before_visit()

      replacement = self.visit(node)

      redirect = None
      if after_visit and replacement:
        replacement, redirect = after_visit(replacement)

      # Falsy results (e.g. None or an empty list) are dropped.
      if replacement:
        if isinstance(replacement, (list, tuple)):
          destination.extend(replacement)
        else:
          destination.append(replacement)

      # The redirection takes effect only for the items that follow; the
      # current replacement has already been stored in the old destination.
      if redirect is not None:
        destination = redirect
    return results
357
358
359# TODO(mdan): Rename to PythonCodeTransformer.
class Base(NodeStateTracker, gast.NodeTransformer):
  """Base class for general-purpose Python-to-Python code transformation.

  This is an extension of ast.NodeTransformer that provides the additional
  functions offered by NodeStateTracker.
  """

  def create_assignment(self, target, expression):
    """Builds the AST for the statement `target = expression`.

    Args:
      target: the assignment target, substituted into the template.
      expression: the assigned value, substituted into the template.

    Returns:
      The result of templates.replace for the assignment template.
    """
    template = """
      target = expression
    """
    return templates.replace(template, target=target, expression=expression)

  # TODO(mdan): Remove.
  def apply_to_single_assignments(self, targets, values, apply_fn):
    """Applies a function to each individual assignment.

    This function can process a possibly-unpacked (e.g. a, b = c, d) assignment.
    It tries to break down the unpacking if possible. In effect, it has the same
    effect as passing the assigned values in SSA form to apply_fn.

    Examples:

    The following will result in apply_fn(a, c), apply_fn(b, d):

        a, b = c, d

    The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]):

        a, b = c

    The following will result in apply_fn(a, (b, c)):

        a = b, c

    It uses the visitor pattern to allow subclasses to process single
    assignments individually.

    Args:
      targets: list, tuple of or individual AST node. Should be used with the
        targets field of an ast.Assign node.
      values: an AST node.
      apply_fn: a function of a single argument, which will be called with the
        respective nodes of each single assignment. The signature is
        apply_fn(target, value), no return value.
    """
    if not isinstance(targets, (list, tuple)):
      targets = (targets,)
    for target in targets:
      if isinstance(target, (gast.Tuple, gast.List)):
        for i, target_el in enumerate(target.elts):
          if isinstance(values, (gast.Tuple, gast.List)):
            value_el = values.elts[i]
          else:
            # The value cannot be unpacked statically, so element i is
            # expressed as `values[i]`. The index must be an AST node rather
            # than a raw Python int, and the subscript is read, not assigned
            # to, hence the Load context.
            idx = parser.parse_expression(str(i))
            value_el = gast.Subscript(values, idx, ctx=gast.Load())
          self.apply_to_single_assignments(target_el, value_el, apply_fn)
      else:
        # TODO(mdan): Look into allowing to rewrite the AST here.
        apply_fn(target, values)

  def visit(self, node):
    """Visits a node, keeping origin information up to date.

    On top of the regular NodeTransformer dispatch, this method:
      * returns the node untouched if it is annotated with SKIP_PROCESSING;
      * points ctx.current_origin at the node's ORIGIN annotation (if it has
        one) for the duration of the visit, restoring the previous value
        afterwards, even when the visit raises;
      * unwraps an Expr node whose value was replaced with a list, a tuple, an
        Assign or an AugAssign;
      * copies the origin of the visited node onto replacement nodes that have
        no ORIGIN annotation of their own.

    Args:
      node: a gast.AST node.

    Returns:
      Whatever the visit_* handler returned, after the adjustments above.

    Raises:
      ValueError: if node is not a gast.AST instance, e.g. a list of nodes.
    """
    if not isinstance(node, gast.AST):
      # This is not that uncommon a mistake: various node bodies are lists, for
      # example, posing a land mine for transformers that need to recursively
      # call `visit`.  The error needs to be raised before the exception handler
      # below is installed, because said handler will mess up if `node` is not,
      # in fact, a node.
      msg = ('invalid value for "node": expected "ast.AST", got "{}"; to'
             ' visit lists of nodes, use "visit_block" instead').format(
                 type(node))
      raise ValueError(msg)

    if anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
      return node

    parent_origin = self.ctx.current_origin
    if anno.hasanno(node, anno.Basic.ORIGIN):
      self.ctx.current_origin = anno.getanno(node, anno.Basic.ORIGIN)

    try:
      processing_expr_node = isinstance(node, gast.Expr)
      if processing_expr_node:
        # Remembered so that a replaced value can be detected after the visit.
        entry_expr_value = node.value

      result = super(Base, self).visit(node)

      # Adjust for consistency: replacing the value of an Expr with
      # an Assign node removes the need for the Expr node.
      if (processing_expr_node and isinstance(result, gast.Expr) and
          (result.value is not entry_expr_value)):
        # When the replacement is a list, it is assumed that the list came
        # from a template that contained a number of statements, which
        # themselves are standalone and don't require an enclosing Expr.
        if isinstance(result.value,
                      (list, tuple, gast.Assign, gast.AugAssign)):
          result = result.value

      # By default, all replacements receive the origin info of the replaced
      # node.
      if result is not node and result is not None:
        inherited_origin = anno.getanno(
            node, anno.Basic.ORIGIN, default=parent_origin)
        if inherited_origin is not None:
          if isinstance(result, (list, tuple)):
            nodes_to_adjust = result
          else:
            nodes_to_adjust = (result,)
          for n in nodes_to_adjust:
            # Origins that the handler set explicitly are left alone.
            if not anno.hasanno(n, anno.Basic.ORIGIN):
              anno.setanno(n, anno.Basic.ORIGIN, inherited_origin)
    finally:
      self.ctx.current_origin = parent_origin

    return result
476
477
class CodeGenerator(NodeStateTracker, gast.NodeVisitor):
  """Base class for general-purpose Python-to-string code transformation.

  Similar to Base, but outputs arbitrary strings instead of a Python AST.

  This relies on the same visitor mechanism as the standard NodeVisitor:
  subclasses write handlers for the different kinds of nodes. Handlers
  generate output by calling emit, which appends to a code buffer. The
  accumulated text is available from code_buffer.

  While visiting, source_map is filled with entries that map a
  (start, end) pair of offsets into the code buffer to the origin information
  of the node whose visit produced that span of output.

  Example:

    class SimpleCodeGen(CodeGenerator):

      def visit_If(self, node):
        self.emit('if ')
        self.visit(node.test)
        self.emit(' { ')
        self.visit_block(node.body)
        self.emit(' } else { ')
        self.visit_block(node.orelse)
        self.emit(' } ')

    node = ast.parse(...)
    gen = SimpleCodeGen(ctx)
    gen.visit(node)
    # gen.code_buffer contains the resulting code
  """

  def __init__(self, ctx):
    super(CodeGenerator, self).__init__(ctx)

    self._output_code = ''
    self.source_map = {}

  def emit(self, code):
    """Appends `code` to the output buffer."""
    self._output_code += code

  @property
  def code_buffer(self):
    """The text emitted so far."""
    return self._output_code

  def visit(self, node):
    """Visits a node, recording which span of output it generated.

    Nodes annotated with SKIP_PROCESSING are ignored. For the duration of the
    visit, ctx.current_origin is set to the node's ORIGIN annotation, if it
    has one; the previous value is restored afterwards, even when the visit
    raises.

    Args:
      node: the node to visit.

    Returns:
      The value returned by the node's handler, or None for skipped nodes.
    """
    if anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
      return

    outer_origin = self.ctx.current_origin
    start = len(self._output_code)
    if anno.hasanno(node, anno.Basic.ORIGIN):
      self.ctx.current_origin = anno.getanno(node, anno.Basic.ORIGIN)

    try:
      ret = super(CodeGenerator, self).visit(node)

      # If this visit changed the size of the buffer, attribute the new span
      # to the node's origin, falling back to the enclosing origin.
      end = len(self._output_code)
      if end != start:
        origin = anno.getanno(node, anno.Basic.ORIGIN, default=outer_origin)
        if origin is not None:
          self.source_map[(start, end)] = origin
      return ret
    finally:
      self.ctx.current_origin = outer_origin
543