• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# coding: utf-8
2from __future__ import unicode_literals, division, absolute_import, print_function
3
4import ast
5import _ast
6import unittest
7import os
8import sys
9
10import asn1crypto as module
11
12
# Maps a fully-qualified dotted import path to the module name that should be
# recorded in its place. This covers imports of a function through a dotted
# path, e.g. "from . import ident", where ident is a function and not a
# submodule. Names without an entry are recorded unchanged by add_mod().
MOD_MAP = {
}
18
19
def add_mod(mod_name, imports):
    """
    Records an imported name, translating it through MOD_MAP first

    :param mod_name:
        A unicode string of the fully-qualified name being imported

    :param imports:
        A set of unicode strings of imported modules - the resolved name is
        added to it in place
    """

    # Fall back to the name itself when no override has been registered
    resolved = MOD_MAP.get(mod_name, mod_name)
    imports.add(resolved)
32
33
def walk_ast(parent_node, modname, imports):
    """
    Walks the AST for a module finding any imports and recording them

    The direct children of parent_node are inspected, along with statements
    nested inside if/elif/else and try/except/else/finally blocks. Function
    and class bodies are not descended into.

    :param parent_node:
        A node from the _ast module

    :param modname:
        A unicode string of the module we are walking the AST of

    :param imports:
        A set of unicode strings of the imports that have been found so far
    """

    _walk_nodes(ast.iter_child_nodes(parent_node), modname, imports)


def _in_package(mod_name):
    """
    Checks if a dotted name is the package under test or one of its submodules

    :param mod_name:
        A unicode string of a fully-qualified module name

    :return:
        A bool
    """

    pkg = module.__name__
    # Compare on a dotted boundary so that an unrelated package that merely
    # shares the prefix is not treated as part of the package
    return mod_name == pkg or mod_name.startswith(pkg + '.')


def _resolve_import_from(node, modname):
    """
    Computes the absolute module path referenced by a "from ... import"

    :param node:
        An _ast.ImportFrom node

    :param modname:
        A unicode string of the module that contains the import statement

    :return:
        A unicode string of the module named by the "from" clause
    """

    if node.level == 0:
        return node.module

    if modname == module.__name__:
        # The package __init__ is already the base of relative imports
        base_mod = module.__name__
    else:
        # Drop one trailing component per leading dot
        base_mod = '.'.join(modname.split('.')[:-node.level])

    if node.module:
        base_mod += '.' + node.module
    return base_mod


def _nested_bodies(node):
    """
    Lists the statement blocks of a conditional or try statement

    :param node:
        A node from the _ast module

    :return:
        A list of lists of statement nodes - empty if node is not an
        if or try statement
    """

    if isinstance(node, _ast.If):
        return [node.body, node.orelse]

    if sys.version_info >= (3, 3):
        if isinstance(node, _ast.Try):
            bodies = [node.body]
            bodies.extend([handler.body for handler in node.handlers])
            bodies.extend([node.orelse, node.finalbody])
            return bodies
        return []

    # Before Python 3.3 try statements were split into two node types
    if isinstance(node, _ast.TryFinally):
        return [node.body, node.finalbody]

    if isinstance(node, _ast.TryExcept):
        bodies = [node.body]
        bodies.extend([handler.body for handler in node.handlers])
        bodies.append(node.orelse)
        return bodies

    return []


def _walk_nodes(nodes, modname, imports):
    """
    Records the package imports found in a sequence of statement nodes

    :param nodes:
        An iterable of nodes from the _ast module

    :param modname:
        A unicode string of the module the nodes belong to

    :param imports:
        A set of unicode strings of the imports that have been found so far
    """

    for node in nodes:
        if isinstance(node, _ast.Import):
            # "import a, b" has one alias per name, so check each of them
            for alias in node.names:
                if _in_package(alias.name):
                    add_mod(alias.name, imports)

        elif isinstance(node, _ast.ImportFrom):
            base_mod = _resolve_import_from(node, modname)
            if not _in_package(base_mod):
                continue

            if node.level > 0 and not node.module:
                # "from . import a, b" - each name is treated as a submodule
                for alias in node.names:
                    add_mod(base_mod + '.' + alias.name, imports)
            else:
                add_mod(base_mod, imports)

        else:
            # Statements in a nested block are examined themselves, not used
            # as parents, so imports directly inside if/try are recorded
            for body in _nested_bodies(node):
                _walk_nodes(body, modname, imports)
98
99
class InitTests(unittest.TestCase):
    """
    Checks the load order declared by the package against the imports found
    in its source files
    """

    def test_load_order(self):
        """
        Parses every .py file under the package directory, collects the
        in-package imports of each one, and verifies that every module named
        by load_order() only imports modules that appear earlier in the list
        """

        deps = {}

        mod_root = os.path.abspath(os.path.dirname(module.__file__))
        files = []
        for root, _dnames, fnames in os.walk(mod_root):
            for fname in fnames:
                if fname.endswith('.py'):
                    full_path = os.path.join(root, fname)
                    # relpath only removes the leading root, where
                    # str.replace would remove every occurrence of it
                    rel_path = os.path.relpath(full_path, mod_root)
                    files.append((full_path, rel_path))

        for full_path, rel_path in sorted(files):
            with open(full_path, 'rb') as f:
                full_code = f.read()
            if sys.version_info >= (3,):
                full_code = full_code.decode('utf-8')

            # splitext strips only the trailing extension
            modname = os.path.splitext(rel_path)[0].replace(os.sep, '.')
            if modname == '__init__':
                modname = module.__name__
            else:
                modname = '%s.%s' % (module.__name__, modname)

            imports = set([])
            module_node = ast.parse(full_code, filename=full_path)
            walk_ast(module_node, modname, imports)

            deps[modname] = imports

        load_order = module.load_order()
        prev = set([])
        for mod in load_order:
            self.assertTrue(
                mod in deps,
                '%s is listed in load_order() but no source file was found' % mod
            )
            # The module name is included in both tuples so that a failure
            # message identifies which module has an unsatisfied import
            self.assertEqual((mod, set([])), (mod, deps[mod] - prev))
            prev.add(mod)
137            prev.add(mod)
138