• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#
2# TreeFragments - parsing of strings to trees
3#
4
5import re
6from StringIO import StringIO
7from Scanning import PyrexScanner, StringSourceDescriptor
8from Symtab import ModuleScope
9import PyrexTypes
10from Visitor import VisitorTransform
11from Nodes import Node, StatListNode
12from ExprNodes import NameNode
13import Parsing
14import Main
15import UtilNodes
16
17"""
18Support for parsing strings into code trees.
19"""
20
class StringParseContext(Main.Context):
    """
    Compilation context for code that is supplied as a string.

    The context remembers the name given to the snippet and will only
    resolve that name (or 'cython') as a module.
    """
    def __init__(self, name, include_directories=None):
        # Give every instance its own list when the caller supplies none.
        if include_directories is None:
            include_directories = []
        Main.Context.__init__(self, include_directories, {},
                              create_testscope=False)
        self.module_name = name

    def find_module(self, module_name, relative_to = None, pos = None, need_pxd = 1):
        """
        Return a fresh ModuleScope for the snippet's own module or 'cython'.

        relative_to, pos and need_pxd are accepted but not used here.
        Any other module name raises AssertionError.
        """
        resolvable = (self.module_name, 'cython')
        if module_name in resolvable:
            return ModuleScope(module_name, parent_module = None, context = self)
        raise AssertionError("Not yet supporting any cimports/includes from string code snippets")
32
def parse_from_strings(name, code, pxds={}, level=None, initial_pos=None,
                       context=None, allow_struct_enum_decorator=False):
    """
    Parse a unicode string of Cython code into a tree.

    This is mostly an internal compiler helper (building the code
    snippets that transforms emit, and unit testing).

    name - a descriptive name for the code source; it is also used as
           the module name and in the default initial position
    code - a unicode string containing the code; it should not carry an
           encoding header
    pxds - accepted but not used by this function
    level - None to parse a whole module with Parsing.p_module;
            otherwise passed on to Parsing.p_code
    initial_pos - start position for the scanner, (name, 1, 0) if None
    context - context to parse in; a StringParseContext(name) is
              created if None
    allow_struct_enum_decorator - forwarded to Parsing.Ctx

    RETURNS

    The tree built by the parser (the module tree when level is None).
    Its scope attribute is set to the scope used when parsing.
    """
    if context is None:
        context = StringParseContext(name)
    # Source files carry an encoding, code fragments should not have to:
    # snippets are therefore required to be unicode strings already.
    assert isinstance(code, unicode), "unicode code snippets only please"

    if initial_pos is None:
        initial_pos = (name, 1, 0)
    source_desc = StringSourceDescriptor(name, code)
    scope = context.find_module(name, pos = initial_pos, need_pxd = 0)

    scanner = PyrexScanner(StringIO(code), source_desc, source_encoding = "UTF-8",
                           scope = scope, context = context, initial_pos = initial_pos)
    parse_ctx = Parsing.Ctx(allow_struct_enum_decorator=allow_struct_enum_decorator)

    if level is not None:
        tree = Parsing.p_code(scanner, level=level, ctx=parse_ctx)
    else:
        tree = Parsing.p_module(scanner, 0, name, ctx=parse_ctx)
        tree.is_pxd = False

    # Both kinds of tree carry the scope they were parsed in.
    tree.scope = scope
    return tree
79
class TreeCopier(VisitorTransform):
    """
    Transform that returns a clone of each node it visits.
    """
    def visit_Node(self, node):
        # None is handed back unchanged.
        if node is None:
            return node
        duplicate = node.clone_node()
        # Continue the traversal on the clone, not on the original node.
        self.visitchildren(duplicate)
        return duplicate
88
class ApplyPositionAndCopy(TreeCopier):
    """
    Tree copier that stamps one fixed position onto every node it copies.

    pos - the position assigned to the pos attribute of each copy
    """
    def __init__(self, pos):
        super(ApplyPositionAndCopy, self).__init__()
        self.pos = pos

    def visit_Node(self, node):
        """
        Copy node as TreeCopier does and set the copy's pos to self.pos.

        Returns the copy, or None when node is None.
        """
        copy = super(ApplyPositionAndCopy, self).visit_Node(node)
        # The inherited visit_Node hands None straight back; there is no
        # node to stamp in that case.
        if copy is not None:
            copy.pos = self.pos
        return copy
98
class TemplateTransform(VisitorTransform):
    """
    Copies a template tree, performing substitutions along the way.

    The transform is called with a dictionary "substitutions" that maps
    names to replacement nodes:
     - An ExprStatNode whose expression is a NameNode with a name found
       in the dictionary is replaced as a whole by a copy of the tree
       stored under that name. The caller must make sure that the
       replacement node is a valid statement.
     - Any other NameNode whose name is found in the dictionary is
       replaced in the same way. The caller must make sure that the
       replacement node is a valid expression.

    It is also called with a list "temps". A NameNode whose name is
    listed there is replaced by a reference to a temporary, and the
    result is wrapped in a TempsBlockNode holding those temporaries.

    Currently supported for tempnames is:
    NameNode
    (various function and class definition nodes etc. should be added to this)

    If a position "pos" is given it is applied to every copied node;
    replacement trees get that position, or else the position of the
    node they replace, applied to every member node.
    """

    # Incremented once per requested temp; not read within this class.
    temp_name_counter = 0

    def __call__(self, node, substitutions, temps, pos):
        self.substitutions = substitutions
        self.pos = pos
        handles = []
        handle_by_name = {}
        for temp_name in temps:
            TemplateTransform.temp_name_counter += 1
            new_handle = UtilNodes.TempHandle(PyrexTypes.py_object_type)
            handle_by_name[temp_name] = new_handle
            handles.append(new_handle)
        self.tempmap = handle_by_name
        transformed = super(TemplateTransform, self).__call__(node)
        if not temps:
            return transformed
        # Only wrap the result when temporaries were actually requested.
        return UtilNodes.TempsBlockNode(self.get_pos(node),
                                        temps=handles,
                                        body=transformed)

    def get_pos(self, node):
        # Prefer the position given to the transform, else the node's own.
        return self.pos or node.pos

    def visit_Node(self, node):
        if node is None:
            return None
        clone = node.clone_node()
        if self.pos is not None:
            clone.pos = self.pos
        self.visitchildren(clone)
        return clone

    def try_substitution(self, node, key):
        replacement = self.substitutions.get(key)
        if replacement is None:
            # Nothing registered under this name: make copy as usual
            return self.visit_Node(node)
        target_pos = self.pos
        if target_pos is None:
            target_pos = node.pos
        return ApplyPositionAndCopy(target_pos)(replacement)

    def visit_NameNode(self, node):
        handle = self.tempmap.get(node.name)
        if not handle:
            return self.try_substitution(node, node.name)
        # The name was requested as a temp: refer to the temporary instead
        return handle.ref(self.get_pos(node))

    def visit_ExprStatNode(self, node):
        # A statement that is nothing but a replaceable name is swapped
        # out as a whole statement, not just the NameNode inside it.
        if not isinstance(node.expr, NameNode):
            return self.visit_Node(node)
        return self.try_substitution(node, node.expr.name)
187
def copy_code_tree(node):
    """Return the result of running a new TreeCopier over node."""
    copier = TreeCopier()
    return copier(node)
190
# The pattern contains no backslashes, so a plain unicode literal is
# equivalent to a raw one.
INDENT_RE = re.compile(u"^ *")
def strip_common_indent(lines):
    """
    Strip blank lines and common indentation from a list of strings.

    Lines that are empty or contain only whitespace are dropped. The
    smallest number of leading spaces found on the remaining lines is
    then removed from each of them; only the space character counts as
    indentation.

    Returns a new list. It is empty when no non-blank line is given.
    """
    # TODO: consider using textwrap.dedent instead
    lines = [x for x in lines if x.strip() != u""]
    if not lines:
        # min() of an empty sequence would raise ValueError
        return lines
    minindent = min([len(INDENT_RE.match(x).group(0)) for x in lines])
    return [x[minindent:] for x in lines]
199
class TreeFragment(object):
    """
    A reusable code tree, built from a unicode snippet or an existing Node.

    The tree is kept in self.root; self.temps holds the names that
    substitute() always passes on as temps.
    """
    def __init__(self, code, name="(tree fragment)", pxds={}, temps=[], pipeline=[], level=None, initial_pos=None):
        """
        code - unicode source (blank lines and common indentation are
               stripped before parsing) or a ready-made Node
        name - name handed to parse_from_strings
        pxds - dict of unicode snippets, only accepted together with
               unicode code; must be empty when code is a Node
        temps - names stored on the fragment for substitute()
        pipeline - callables applied in order to the parsed tree; None
                   entries are skipped. Not applied when code is a Node.
        level, initial_pos - forwarded to parse_from_strings

        Raises NotImplementedError for a Node combined with pxds, and
        ValueError when code is neither unicode nor a Node.
        """
        if isinstance(code, unicode):
            def fmt(x): return u"\n".join(strip_common_indent(x.split(u"\n")))

            fmt_code = fmt(code)
            fmt_pxds = {}
            for key, value in pxds.iteritems():
                fmt_pxds[key] = fmt(value)
            mod = t = parse_from_strings(name, fmt_code, fmt_pxds, level=level, initial_pos=initial_pos)
            if level is None:
                t = t.body # Make sure a StatListNode is at the top
            if not isinstance(t, StatListNode):
                t = StatListNode(pos=mod.pos, stats=[t])
            for transform in pipeline:
                if transform is None:
                    continue
                t = transform(t)
            self.root = t
        elif isinstance(code, Node):
            if pxds != {}: raise NotImplementedError()
            self.root = code
        else:
            raise ValueError("Unrecognized code format (accepts unicode and Node)")
        # Keep a private copy: storing the argument itself would make all
        # fragments built without temps share the one default list.
        self.temps = list(temps)

    def copy(self):
        """Return a copy of the fragment's tree."""
        return copy_code_tree(self.root)

    def substitute(self, nodes={}, temps=[], pos = None):
        """
        Return the result of TemplateTransform applied to the tree.

        nodes - substitutions, mapping names to replacement nodes
        temps - extra temp names, used in addition to the fragment's own
        pos - position passed to the transform
        """
        # list() lets callers pass any sequence without touching either
        # the stored temps or the default argument.
        all_temps = list(self.temps) + list(temps)
        return TemplateTransform()(self.root,
                                   substitutions = nodes,
                                   temps = all_temps, pos = pos)
233
class SetPosTransform(VisitorTransform):
    """
    Transform that assigns one position to every node it visits.

    The nodes are modified in place; no copy is made.
    """
    def __init__(self, pos):
        super(SetPosTransform, self).__init__()
        self.pos = pos

    def visit_Node(self, node):
        # Stamp this node first, then carry on with its children.
        new_pos = self.pos
        node.pos = new_pos
        self.visitchildren(node)
        return node
243