"""This script contains the actual auditing tests.

It should not be imported directly, but should be run by the test_audit
module with arguments identifying each test.

The first command-line argument names the test function to run; some
tests (e.g. test_open) additionally read a file path from the second
argument.
"""

import contextlib
import sys

# Only the shared helpers are exported. The test_* functions are dispatched
# by name from the __main__ block at the bottom of this file and should not
# be picked up by star-imports.
__all__ = [
    "TestHook",
    "TestFinalizeHook",
    "assertEqual",
    "assertIn",
    "assertNotIn",
    "assertRaises",
    "assertSequenceEqual",
]


class TestHook:
    """Used in standard hook tests to collect any logged events.

    Should be used in a with block to ensure that it has no impact
    after the test completes. Note that leaving the with block does not
    unregister the hook (nothing here removes it); close() only makes the
    hook ignore every later event.

    Args:
        raise_on_events: event names that make the hook raise. Either a
            single event name (str) or a collection of names; matching is
            by exact event name.
        exc_type: exception type raised when one of those events is seen.
    """

    def __init__(self, raise_on_events=None, exc_type=RuntimeError):
        if isinstance(raise_on_events, str):
            # ``event in "some.name"`` would be a substring test, so e.g.
            # "sys" or "audit" would wrongly match "sys.addaudithook".
            # Treat a bare string as a single event name instead.
            raise_on_events = (raise_on_events,)
        self.raise_on_events = raise_on_events or ()
        self.exc_type = exc_type
        # (event, args) pairs recorded while the hook is open.
        self.seen = []
        self.closed = False

    def __enter__(self, *a):
        sys.addaudithook(self)
        return self

    def __exit__(self, *a):
        self.close()

    def close(self):
        """Stop recording (and raising on) events."""
        self.closed = True

    @property
    def seen_events(self):
        """Names of the recorded events, in the order they were seen."""
        return [i[0] for i in self.seen]

    def __call__(self, event, args):
        if self.closed:
            return
        self.seen.append((event, args))
        if event in self.raise_on_events:
            raise self.exc_type("saw event " + event)


class TestFinalizeHook:
    """Used in the test_finalize_hooks function to ensure that hooks
    are correctly cleaned up, that they are notified about the cleanup,
    and are unable to prevent it.

    Every event (other than builtins.id) is printed to stdout together with
    the hook's id(), so the driving test can inspect the output.
    """

    def __init__(self):
        print("Created", id(self), file=sys.stdout, flush=True)

    def __call__(self, event, args):
        # Avoid recursion when we call id() below
        if event == "builtins.id":
            return

        print(event, id(self), file=sys.stdout, flush=True)

        # Raising during interpreter cleanup must not stop that cleanup.
        if event == "cpython._PySys_ClearAuditHooks":
            raise RuntimeError("Should be ignored")
        elif event == "cpython.PyInterpreterState_Clear":
            raise RuntimeError("Should be ignored")


# Simple helpers, since we are not in unittest here.
# They raise AssertionError explicitly (rather than using ``assert``) so
# they keep working when Python runs with -O.
def assertEqual(x, y):
    if x != y:
        raise AssertionError(f"{x!r} should equal {y!r}")


def assertIn(el, series):
    if el not in series:
        raise AssertionError(f"{el!r} should be in {series!r}")


def assertNotIn(el, series):
    if el in series:
        raise AssertionError(f"{el!r} should not be in {series!r}")


def assertSequenceEqual(x, y):
    if len(x) != len(y):
        raise AssertionError(f"{x!r} should equal {y!r}")
    if any(ix != iy for ix, iy in zip(x, y)):
        raise AssertionError(f"{x!r} should equal {y!r}")


@contextlib.contextmanager
def assertRaises(ex_type):
    """Require the with-body to raise exactly ``ex_type``.

    Raises AssertionError if nothing is raised or if the raised exception's
    type is not exactly ``ex_type`` (subclasses do not count). An
    AssertionError raised by the body is propagated unchanged.
    """
    try:
        yield
    except BaseException as ex:
        if isinstance(ex, AssertionError):
            raise
        if type(ex) is not ex_type:
            raise AssertionError(f"{ex} should be {ex_type}") from ex
    else:
        raise AssertionError(f"expected {ex_type}")


def test_basic():
    with TestHook() as hook:
        sys.audit("test_event", 1, 2, 3)
        assertEqual(hook.seen[0][0], "test_event")
        assertEqual(hook.seen[0][1], (1, 2, 3))


def test_block_add_hook():
    # Raising an exception should prevent a new hook from being added,
    # but will not propagate out.
    with TestHook(raise_on_events="sys.addaudithook") as hook1:
        with TestHook() as hook2:
            sys.audit("test_event")
            assertIn("test_event", hook1.seen_events)
            assertNotIn("test_event", hook2.seen_events)


def test_block_add_hook_baseexception():
    # Raising BaseException will propagate out when adding a hook
    with assertRaises(BaseException):
        with TestHook(
            raise_on_events="sys.addaudithook", exc_type=BaseException
        ) as hook1:
            # Adding this next hook should raise BaseException
            with TestHook() as hook2:
                pass


def test_finalize_hooks():
    sys.addaudithook(TestFinalizeHook())


def test_pickle():
    import pickle

    class PicklePrint:
        def __reduce_ex__(self, p):
            return str, ("Pwned!",)

    payload_1 = pickle.dumps(PicklePrint())
    payload_2 = pickle.dumps(("a", "b", "c", 1, 2, 3))

    # Before we add the hook, ensure our malicious pickle loads
    assertEqual("Pwned!", pickle.loads(payload_1))

    with TestHook(raise_on_events="pickle.find_class") as hook:
        with assertRaises(RuntimeError):
            # With the hook enabled, loading globals is not allowed
            pickle.loads(payload_1)
        # pickles with no globals are okay
        pickle.loads(payload_2)


def test_monkeypatch():
    class A:
        pass

    class B:
        pass

    class C(A):
        pass

    a = A()

    with TestHook() as hook:
        # Catch name changes
        C.__name__ = "X"
        # Catch type changes
        C.__bases__ = (B,)
        # Ensure bypassing __setattr__ is still caught
        type.__dict__["__bases__"].__set__(C, (B,))
        # Catch attribute replacement
        C.__init__ = B.__init__
        # Catch attribute addition
        C.new_attr = 123
        # Catch class changes
        a.__class__ = B

    # Only the special-attribute writes above are expected to be reported
    # as object.__setattr__ events (see the expected list below).
    actual = [
        (ev_args[0], ev_args[1])
        for event, ev_args in hook.seen
        if event == "object.__setattr__"
    ]
    assertSequenceEqual(
        [(C, "__name__"), (C, "__bases__"), (C, "__bases__"), (a, "__class__")], actual
    )


def test_open():
    # SSLContext.load_dh_params uses _Py_fopen_obj rather than normal open()
    try:
        import ssl

        load_dh_params = ssl.create_default_context().load_dh_params
    except ImportError:
        load_dh_params = None

    # Try a range of "open" functions.
    # All of them should fail
    with TestHook(raise_on_events={"open"}) as hook:
        for fn, *args in [
            (open, sys.argv[2], "r"),
            (open, sys.executable, "rb"),
            (open, 3, "wb"),
            (open, sys.argv[2], "w", -1, None, None, None, False, lambda *a: 1),
            (load_dh_params, sys.argv[2]),
        ]:
            if not fn:
                continue
            with assertRaises(RuntimeError):
                fn(*args)

    # "open" events carry (path, mode, flags); split on whether a mode is set.
    actual_mode = [(a[0], a[1]) for e, a in hook.seen if e == "open" and a[1]]
    actual_flag = [(a[0], a[2]) for e, a in hook.seen if e == "open" and not a[1]]
    assertSequenceEqual(
        [
            i
            for i in [
                (sys.argv[2], "r"),
                (sys.executable, "r"),
                (3, "w"),
                (sys.argv[2], "w"),
                (sys.argv[2], "rb") if load_dh_params else None,
            ]
            if i is not None
        ],
        actual_mode,
    )
    assertSequenceEqual([], actual_flag)


def test_cantrace():
    traced = []

    def trace(frame, event, *args):
        if frame.f_code == TestHook.__call__.__code__:
            traced.append(event)

    old = sys.settrace(trace)
    try:
        with TestHook() as hook:
            # No traced call
            eval("1")

            # No traced call
            hook.__cantrace__ = False
            eval("2")

            # One traced call
            hook.__cantrace__ = True
            eval("3")

            # Two traced calls (writing to private member, eval)
            hook.__cantrace__ = 1
            eval("4")

            # One traced call (writing to private member)
            hook.__cantrace__ = 0
    finally:
        sys.settrace(old)

    assertSequenceEqual(["call"] * 4, traced)


def test_mmap():
    import mmap

    with TestHook() as hook:
        mmap.mmap(-1, 8)
        assertEqual(hook.seen[0][1][:2], (-1, 8))


def test_excepthook():
    def excepthook(exc_type, exc_value, exc_tb):
        if exc_type is not RuntimeError:
            sys.__excepthook__(exc_type, exc_value, exc_tb)

    def hook(event, args):
        if event == "sys.excepthook":
            if not isinstance(args[2], args[1]):
                raise TypeError(f"Expected isinstance({args[2]!r}, " f"{args[1]!r})")
            if args[0] != excepthook:
                raise ValueError(f"Expected {args[0]} == {excepthook}")
            print(event, repr(args[2]))

    sys.addaudithook(hook)
    sys.excepthook = excepthook
    # Deliberately uncaught so that sys.excepthook (and its audit event) fires.
    raise RuntimeError("fatal-error")


def test_unraisablehook():
    from _testcapi import write_unraisable_exc

    def unraisablehook(hookargs):
        pass

    def hook(event, args):
        if event == "sys.unraisablehook":
            if args[0] != unraisablehook:
                raise ValueError(f"Expected {args[0]} == {unraisablehook}")
            print(event, repr(args[1].exc_value), args[1].err_msg)

    sys.addaudithook(hook)
    sys.unraisablehook = unraisablehook
    write_unraisable_exc(RuntimeError("nonfatal-error"), "for audit hook test", None)


def test_winreg():
    from winreg import OpenKey, EnumKey, CloseKey, HKEY_LOCAL_MACHINE

    def hook(event, args):
        if not event.startswith("winreg."):
            return
        print(event, *args)

    sys.addaudithook(hook)

    k = OpenKey(HKEY_LOCAL_MACHINE, "Software")
    EnumKey(k, 0)
    try:
        EnumKey(k, 10000)
    except OSError:
        pass
    else:
        raise RuntimeError("Expected EnumKey(HKLM, 10000) to fail")

    kv = k.Detach()
    CloseKey(kv)


if __name__ == "__main__":
    from test.libregrtest.setup import suppress_msvcrt_asserts

    suppress_msvcrt_asserts(False)

    # Dispatch to the test function named by the first argument.
    test = sys.argv[1]
    globals()[test]()