• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import Cython.Compiler.Errors as Errors
2from Cython.CodeWriter import CodeWriter
3from Cython.Compiler.TreeFragment import TreeFragment, strip_common_indent
4from Cython.Compiler.Visitor import TreeVisitor, VisitorTransform
5from Cython.Compiler import TreePath
6
7import unittest
8import os, sys
9import tempfile
10
11
class NodeTypeWriter(TreeVisitor):
    """Visitor that records every visited node as an indented text line.

    After visiting, ``result`` holds one ``"<name>: <ClassName>"`` line per
    node, indented by two spaces per nesting level. ``<name>`` is
    ``(root)`` for the starting node. Otherwise it is the attribute name
    taken from the last ``access_path`` entry, with ``[index]`` appended
    when that entry carries an index.
    """

    def __init__(self):
        super(NodeTypeWriter, self).__init__()
        self.result = []
        self._indents = 0

    def visit_Node(self, node):
        if self.access_path:
            last_step = self.access_path[-1]
            if last_step[2] is None:
                label = last_step[1]
            else:
                # (attribute name, index) pair -> "attr[index]"
                label = u"%s[%d]" % last_step[1:3]
        else:
            label = u"(root)"

        prefix = u"  " * self._indents
        self.result.append(prefix + u"%s: %s" % (label, node.__class__.__name__))

        # Children are rendered one level deeper than this node.
        self._indents += 1
        self.visitchildren(node)
        self._indents -= 1
33
34
def treetypes(root):
    """Render the tree under ``root`` as newline-separated class-name lines.

    The returned text starts and ends with a newline. It can therefore be
    compared by plain string equality against a triple-quoted literal in a
    test case while keeping that literal readable.
    """
    writer = NodeTypeWriter()
    writer.visit(root)
    lines = [u""]
    lines.extend(writer.result)
    lines.append(u"")
    return u"\n".join(lines)
43
44
class CythonTest(unittest.TestCase):
    """Base class for Cython compiler unit tests.

    Disables the compiler's error listing/echo output while each test runs.
    It also provides helpers to build tree fragments, serialise trees back
    to code, and compare results line by line.
    """

    def setUp(self):
        # Save and disable the global error listing/echo files; they are
        # restored in tearDown().
        self.listing_file = Errors.listing_file
        self.echo_file = Errors.echo_file
        Errors.listing_file = Errors.echo_file = None

    def tearDown(self):
        # Restore the values saved in setUp().
        Errors.listing_file = self.listing_file
        Errors.echo_file = self.echo_file

    def assertLines(self, expected, result):
        """Check that two strings or lists of strings are equal line by line.

        Strings are split on newlines; lists are compared as given.
        """
        if not isinstance(expected, list):
            expected = expected.split(u"\n")
        if not isinstance(result, list):
            result = result.split(u"\n")
        for idx, (expected_line, result_line) in enumerate(zip(expected, result)):
            self.assertEqual(expected_line, result_line,
                             "Line %d:\nExp: %s\nGot: %s" % (idx, expected_line, result_line))
        # "Got" shows the actual result and "Expected" the expected lines
        # (these two payloads used to be swapped).
        self.assertEqual(len(expected), len(result),
            "Unmatched lines. Got:\n%s\nExpected:\n%s" % (u"\n".join(result), "\n".join(expected)))

    def codeToLines(self, tree):
        """Serialise ``tree`` with CodeWriter and return its output lines."""
        writer = CodeWriter()
        writer.write(tree)
        return writer.result.lines

    def codeToString(self, tree):
        """Serialise ``tree`` with CodeWriter and return it as one string."""
        return "\n".join(self.codeToLines(tree))

    def assertCode(self, expected, result_tree):
        """Serialise ``result_tree`` and compare it line by line with ``expected``.

        ``expected`` is a code string. It is split into lines and
        normalised with strip_common_indent() before the comparison.
        """
        result_lines = self.codeToLines(result_tree)

        expected_lines = strip_common_indent(expected.split("\n"))

        for idx, (line, expected_line) in enumerate(zip(result_lines, expected_lines)):
            self.assertEqual(expected_line, line, "Line %d:\nGot: %s\nExp: %s" % (idx, line, expected_line))
        self.assertEqual(len(result_lines), len(expected_lines),
            "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result_lines), expected))

    def assertNodeExists(self, path, result_tree):
        """Fail unless the TreePath expression ``path`` matches in ``result_tree``."""
        self.assertNotEqual(TreePath.find_first(result_tree, path), None,
                            "Path '%s' not found in result tree" % path)

    def fragment(self, code, pxds=None, pipeline=None):
        """Create a tree fragment named after this test case (used in parse errors).

        ``pxds`` and ``pipeline`` are forwarded to TreeFragment. They default
        to a fresh empty dict/list on every call instead of a shared mutable
        default.
        """
        if pxds is None:
            pxds = {}
        if pipeline is None:
            pipeline = []
        name = self.id()
        if name.startswith("__main__."):
            name = name[len("__main__."):]
        name = name.replace(".", "_")
        return TreeFragment(code, name, pxds, pipeline=pipeline)

    def treetypes(self, root):
        """Return the class-name rendering of ``root`` (see module-level treetypes)."""
        return treetypes(root)

    def should_fail(self, func, exc_type=Exception):
        """Call ``func`` and fail unless it raises ``exc_type`` (any Exception by default).

        Returns the exception that was raised.
        """
        try:
            func()
        except exc_type as e:
            return e
        # Reached only when func() returned normally. The failure is raised
        # outside the try block so that the AssertionError from self.fail()
        # cannot be caught by `except exc_type` (it would be caught when
        # exc_type is Exception, the default).
        self.fail("Expected an exception of type %r" % exc_type)

    def should_not_fail(self, func):
        """Call ``func`` and return its result, converting any exception into a test failure."""
        try:
            return func()
        except Exception as e:
            # Catch Exception rather than using a bare except, so that
            # KeyboardInterrupt and SystemExit are not turned into failures.
            self.fail(str(e))
116
117
class TransformTest(CythonTest):
    """
    Utility base class for transform unit tests. It is based around constructing
    test trees (either explicitly or by parsing a Cython code string), running
    the transform, serializing the result with a customized Cython serializer
    (with special markup for nodes that cannot be represented in Cython),
    and doing a line-by-line string comparison of the result.

    To create a test case:
     - Call run_pipeline. The pipeline should at least contain the transform you
       are testing; pyx should be either a string (passed to the parser to
       create a post-parse tree) or a node representing input to pipeline.
       The result will be a transformed result.

     - Check that the tree is correct. If wanted, assertCode can be used, which
       takes a code string as expected, and a ModuleNode in result_tree
       (it serializes the ModuleNode to a string and compares line-by-line).

    All code strings are first stripped of whitespace-only lines and then of
    common indentation.

    run_pipeline also accepts an optional ``pxds`` dictionary, which is
    passed on to the tree fragment.
    """

    def run_pipeline(self, pipeline, pyx, pxds=None):
        """Build a tree fragment from ``pyx`` and pass its root through each
        callable in ``pipeline`` in order, returning the final tree.

        ``pxds`` is forwarded to ``self.fragment``. It defaults to a fresh
        empty dict on every call instead of a shared mutable default.
        """
        if pxds is None:
            pxds = {}
        tree = self.fragment(pyx, pxds).root
        for transform in pipeline:
            tree = transform(tree)
        return tree
148
149
class TreeAssertVisitor(VisitorTransform):
    """Evaluate the tree-path test directives on each CompilerDirectivesNode.

    For ``test_assert_path_exists``, every listed TreePath expression must
    match below the directives node. For ``test_fail_if_path_exists``, none
    of them may match. Each violation is reported with Errors.error() at the
    directives node's position.
    """
    # Only inspection is needed here, so a TreeVisitor would suffice, but this
    # class has to run as a step in the compiler pipeline.

    def visit_CompilerDirectivesNode(self, node):
        directives = node.directives
        # (directive name, whether the path must be present, error template)
        path_checks = (
            ('test_assert_path_exists', True,
             "Expected path '%s' not found in result tree"),
            ('test_fail_if_path_exists', False,
             "Unexpected path '%s' found in result tree"),
        )
        for directive_name, must_exist, template in path_checks:
            if directive_name not in directives:
                continue
            for path in directives[directive_name]:
                present = TreePath.find_first(node, path) is not None
                if present != must_exist:
                    Errors.error(node.pos, template % path)
        self.visitchildren(node)
        return node

    visit_Node = VisitorTransform.recurse_to_children
172
173
def unpack_source_tree(tree_file, dir=None):
    """Split a source-tree file into individual files below ``dir``.

    A line starting with ``#####`` begins a new file. Its relative path is
    the line's text with surrounding whitespace and ``#`` characters removed,
    and ``/`` is translated to ``os.path.sep``. Missing parent directories are
    created. All following lines are copied verbatim into that file.

    Before the first marker, lines that are non-blank, not ``#`` comments,
    and not a bare triple-quote delimiter are collected as the header.

    When ``dir`` is None, a fresh directory is created with tempfile.mkdtemp().
    Returns a ``(dir, header_text)`` tuple.
    """
    if dir is None:
        dir = tempfile.mkdtemp()
    header_lines = []
    out = None
    with open(tree_file) as source:
        all_lines = source.readlines()
    try:
        for line in all_lines:
            if line.startswith('#####'):
                rel_name = line.strip().strip('#').strip().replace('/', os.path.sep)
                target = os.path.join(dir, rel_name)
                target_dir = os.path.dirname(target)
                if not os.path.exists(target_dir):
                    os.makedirs(target_dir)
                if out is not None:
                    previous, out = out, None
                    previous.close()
                out = open(target, 'w')
            elif out is not None:
                out.write(line)
            else:
                stripped = line.strip()
                is_comment = line.lstrip().startswith('#')
                is_quote_delimiter = stripped in ('"""', "'''")
                if stripped and not is_comment and not is_quote_delimiter:
                    header_lines.append(line)
    finally:
        # Close the file that was still being written when the loop ended
        # (or was interrupted by an exception).
        if out is not None:
            out.close()
    return dir, ''.join(header_lines)
205