• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""This script contains the actual auditing tests.
2
3It should not be imported directly, but should be run by the test_audit
4module with arguments identifying each test.
5
6"""
7
8import contextlib
9import sys
10
11
class TestHook:
    """Used in standard hook tests to collect any logged events.

    Should be used in a with block to ensure that it has no impact
    after the test completes: entering registers the instance with
    sys.addaudithook, and exiting calls close(), after which every event
    is ignored. close() does not unregister the hook; it only turns
    __call__ into a no-op.

    raise_on_events -- event names for which the hook raises exc_type
        (after recording the event). May be any container of names, or a
        single name given as a plain string.
    exc_type -- exception class raised for those events.
    """

    def __init__(self, raise_on_events=None, exc_type=RuntimeError):
        events = raise_on_events or ()
        # A bare string would be searched with substring semantics
        # ("sys" in "sys.addaudithook" is True), so treat it as one name.
        if isinstance(events, str):
            events = (events,)
        self.raise_on_events = events
        self.exc_type = exc_type
        self.seen = []  # (event, args) pairs, in arrival order
        self.closed = False

    def __enter__(self, *a):
        sys.addaudithook(self)
        return self

    def __exit__(self, *a):
        self.close()

    def close(self):
        """Stop recording and raising; the hook itself stays registered."""
        self.closed = True

    @property
    def seen_events(self):
        """Names of the recorded events, in arrival order."""
        return [i[0] for i in self.seen]

    def __call__(self, event, args):
        if self.closed:
            return
        self.seen.append((event, args))
        if event in self.raise_on_events:
            raise self.exc_type("saw event " + event)
45
46
class TestFinalizeHook:
    """Audit hook for test_finalize_hooks.

    It reports its own creation and every event it receives on stdout, so
    the caller can see that hooks are notified when they are cleaned up.
    It raises from the cleanup events to show that doing so cannot prevent
    the cleanup.
    """

    def __init__(self):
        print("Created", id(self), file=sys.stdout, flush=True)

    def __call__(self, event, args):
        # Calling id() below raises "builtins.id" itself; skipping that
        # event keeps the hook from recursing into itself.
        if event != "builtins.id":
            print(event, id(self), file=sys.stdout, flush=True)
            if event in ("cpython._PySys_ClearAuditHooks",
                         "cpython.PyInterpreterState_Clear"):
                raise RuntimeError("Should be ignored")
67
68
69# Simple helpers, since we are not in unittest here
def assertEqual(x, y):
    """Raise AssertionError when *x* and *y* compare unequal."""
    if x != y:
        message = f"{x!r} should equal {y!r}"
        raise AssertionError(message)
73
74
def assertIn(el, series):
    """Raise AssertionError unless *el* is a member of *series*."""
    if el in series:
        return
    raise AssertionError(f"{el!r} should be in {series!r}")
78
79
def assertNotIn(el, series):
    """Raise AssertionError if *el* is a member of *series*."""
    if el not in series:
        return
    raise AssertionError(f"{el!r} should not be in {series!r}")
83
84
def assertSequenceEqual(x, y):
    """Raise AssertionError unless *x* and *y* have the same length and
    pairwise-equal items (the container types themselves may differ)."""
    mismatch = len(x) != len(y) or any(
        left != right for left, right in zip(x, y)
    )
    if mismatch:
        raise AssertionError(f"{x!r} should equal {y!r}")
90
91
@contextlib.contextmanager
def assertRaises(ex_type):
    """Require the with-body to raise an exception of exactly *ex_type*.

    The exception is suppressed when its type is exactly *ex_type*
    (subclasses do not count). An AssertionError raised by the body is
    always re-raised unchanged, so failures from the other assert helpers
    are not masked; as a consequence this helper cannot be used to expect
    AssertionError itself. If the body raises nothing, or raises a
    different type, AssertionError is raised.

    Failures are raised explicitly rather than via ``assert`` statements,
    which are removed when Python runs with -O and would make this helper
    pass silently.
    """
    try:
        yield
    except BaseException as ex:
        if isinstance(ex, AssertionError):
            raise
        if type(ex) is not ex_type:
            raise AssertionError(f"{ex} should be {ex_type}") from ex
    else:
        raise AssertionError(f"expected {ex_type}")
101
102
def test_basic():
    # A plain sys.audit call is delivered to the hook with its arguments
    # packed into a tuple.
    with TestHook() as hook:
        sys.audit("test_event", 1, 2, 3)
        first_event, first_args = hook.seen[0]
        assertEqual(first_event, "test_event")
        assertEqual(first_args, (1, 2, 3))
108
109
def test_block_add_hook():
    # hook1 raises RuntimeError when hook2 is being added. That should stop
    # hook2 from being registered without the error escaping, so only hook1
    # records the later event.
    with TestHook(raise_on_events="sys.addaudithook") as hook1, TestHook() as hook2:
        sys.audit("test_event")
        assertIn("test_event", hook1.seen_events)
        assertNotIn("test_event", hook2.seen_events)
118
119
def test_block_add_hook_baseexception():
    # Unlike an ordinary exception, a BaseException raised by an existing
    # hook while another hook is being added is expected to propagate.
    blocker = TestHook(raise_on_events="sys.addaudithook", exc_type=BaseException)
    with assertRaises(BaseException):
        with blocker:
            # Registering this second hook should raise BaseException
            with TestHook():
                pass
129
130
def test_finalize_hooks():
    # The hook stays registered after this function returns, so it also
    # receives the events raised while the interpreter cleans up.
    finalize_hook = TestFinalizeHook()
    sys.addaudithook(finalize_hook)
133
134
def test_pickle():
    import pickle

    class PicklePrint:
        # Unpickles by calling str("Pwned!"), i.e. a global lookup + call.
        def __reduce_ex__(self, protocol):
            return str, ("Pwned!",)

    malicious = pickle.dumps(PicklePrint())
    benign = pickle.dumps(("a", "b", "c", 1, 2, 3))

    # Without the hook, the global-referencing pickle loads fine.
    assertEqual("Pwned!", pickle.loads(malicious))

    with TestHook(raise_on_events="pickle.find_class"):
        # Resolving globals is refused while the hook is active ...
        with assertRaises(RuntimeError):
            pickle.loads(malicious)
        # ... but a pickle that references no globals still loads.
        pickle.loads(benign)
154
155
def test_monkeypatch():
    class A:
        pass

    class B:
        pass

    class C(A):
        pass

    a = A()

    with TestHook() as hook:
        # Renaming a class
        C.__name__ = "X"
        # Rebasing a class ...
        C.__bases__ = (B,)
        # ... also when type's __setattr__ is bypassed via the descriptor
        type.__dict__["__bases__"].__set__(C, (B,))
        # Replacing and adding ordinary class attributes; neither is in the
        # expected list below
        C.__init__ = B.__init__
        C.new_attr = 123
        # Changing an instance's class
        a.__class__ = B

    recorded = [
        (args[0], args[1]) for event, args in hook.seen if event == "object.__setattr__"
    ]
    expected = [(C, "__name__"), (C, "__bases__"), (C, "__bases__"), (a, "__class__")]
    assertSequenceEqual(expected, recorded)
186
187
def test_open():
    # SSLContext.load_dh_params uses _Py_fopen_obj rather than normal open(),
    # so it is exercised too when ssl is available.
    try:
        import ssl

        load_dh_params = ssl.create_default_context().load_dh_params
    except ImportError:
        load_dh_params = None

    # Each call below should be refused: the hook raises RuntimeError on
    # every "open" event.
    with TestHook(raise_on_events={"open"}) as hook:
        path = sys.argv[2]
        for opener, *call_args in [
            (open, path, "r"),
            (open, sys.executable, "rb"),
            (open, 3, "wb"),
            (open, path, "w", -1, None, None, None, False, lambda *a: 1),
            (load_dh_params, path),
        ]:
            if not opener:
                continue
            with assertRaises(RuntimeError):
                opener(*call_args)

    open_args = [args for event, args in hook.seen if event == "open"]
    actual_mode = [(args[0], args[1]) for args in open_args if args[1]]
    actual_flag = [(args[0], args[2]) for args in open_args if not args[1]]

    # open() is expected to report its mode without "b", while
    # load_dh_params reports "rb".
    expected_mode = [(path, "r"), (sys.executable, "r"), (3, "w"), (path, "w")]
    if load_dh_params:
        expected_mode.append((path, "rb"))
    assertSequenceEqual(expected_mode, actual_mode)
    # No "open" event is expected without a mode string.
    assertSequenceEqual([], actual_flag)
229
230
def test_cantrace():
    """Check that a hook's __call__ is traced only while __cantrace__ is truthy.

    The trace function records events only for frames running
    TestHook.__call__. It returns None, so each traced invocation yields a
    single "call" event; the expected total of 4 is the sum of the
    per-step counts noted below.
    """
    traced = []

    def trace(frame, event, *args):
        if frame.f_code == TestHook.__call__.__code__:
            traced.append(event)

    # sys.settrace() returns None, so the previous tracer (e.g. a debugger
    # or coverage tool) must be fetched with gettrace() to be restored.
    old = sys.gettrace()
    sys.settrace(trace)
    try:
        with TestHook() as hook:
            # No traced call
            eval("1")

            # No traced call
            hook.__cantrace__ = False
            eval("2")

            # One traced call
            hook.__cantrace__ = True
            eval("3")

            # Two traced calls (writing to private member, eval)
            hook.__cantrace__ = 1
            eval("4")

            # One traced call (writing to private member)
            hook.__cantrace__ = 0
    finally:
        sys.settrace(old)

    assertSequenceEqual(["call"] * 4, traced)
262
263
def test_mmap():
    import mmap

    # The first event after registration comes from creating the mapping;
    # its leading arguments echo the constructor's.
    with TestHook() as hook:
        mmap.mmap(-1, 8)
        _, first_args = hook.seen[0]
        assertEqual(first_args[:2], (-1, 8))
270
271
def test_excepthook():
    def excepthook(exc_type, exc_value, exc_tb):
        # The RuntimeError raised below is expected and dropped; anything
        # else goes to the default hook.
        if exc_type is RuntimeError:
            return
        sys.__excepthook__(exc_type, exc_value, exc_tb)

    def audit_hook(event, args):
        if event != "sys.excepthook":
            return
        # Per the checks below: args[0] is the hook being invoked, args[1]
        # and args[2] the exception type and value.
        if not isinstance(args[2], args[1]):
            raise TypeError(f"Expected isinstance({args[2]!r}, {args[1]!r})")
        if args[0] != excepthook:
            raise ValueError(f"Expected {args[0]} == {excepthook}")
        print(event, repr(args[2]))

    sys.addaudithook(audit_hook)
    sys.excepthook = excepthook
    raise RuntimeError("fatal-error")
288
289
def test_unraisablehook():
    from _testcapi import write_unraisable_exc

    def unraisablehook(hookargs):
        pass

    def audit_hook(event, args):
        if event != "sys.unraisablehook":
            return
        if args[0] != unraisablehook:
            raise ValueError(f"Expected {args[0]} == {unraisablehook}")
        unraisable = args[1]
        print(event, repr(unraisable.exc_value), unraisable.err_msg)

    sys.addaudithook(audit_hook)
    sys.unraisablehook = unraisablehook
    write_unraisable_exc(RuntimeError("nonfatal-error"), "for audit hook test", None)
305
306
def test_winreg():
    from winreg import OpenKey, EnumKey, CloseKey, HKEY_LOCAL_MACHINE

    def audit_hook(event, args):
        # Print winreg events only; every other event is ignored.
        if event.startswith("winreg."):
            print(event, *args)

    sys.addaudithook(audit_hook)

    key = OpenKey(HKEY_LOCAL_MACHINE, "Software")
    EnumKey(key, 0)
    try:
        EnumKey(key, 10000)
    except OSError:
        pass
    else:
        raise RuntimeError("Expected EnumKey(HKLM, 10000) to fail")

    CloseKey(key.Detach())
328
329
if __name__ == "__main__":
    from test.libregrtest.setup import suppress_msvcrt_asserts

    suppress_msvcrt_asserts(False)

    # argv[1] names the module-level test function to run.
    test_name = sys.argv[1]
    globals()[test_name]()
337