"""This script contains the actual auditing tests.

It should not be imported directly, but should be run by the test_audit
module with arguments identifying each test.

"""

import contextlib
import os
import sys


# Only the helpers are exported by a star-import. The test_* functions are
# entry points looked up by name in the __main__ block below; keeping them
# out of __all__ stops them leaking into (and being collected by) another
# module that star-imports this one.
__all__ = [
    "TestHook",
    "assertEqual",
    "assertIn",
    "assertNotIn",
    "assertSequenceEqual",
    "assertRaises",
]


class TestHook:
    """Used in standard hook tests to collect any logged events.

    Should be used in a with block to ensure that it has no impact
    after the test completes.

    raise_on_events is a container of event names (or a single event name
    given as a string) for which the hook raises exc_type after recording
    the event.
    """

    def __init__(self, raise_on_events=None, exc_type=RuntimeError):
        # A bare string would otherwise be searched as a sequence of
        # characters, turning the membership test in __call__ into a
        # substring match.  Wrap it so event names are compared exactly.
        if isinstance(raise_on_events, str):
            raise_on_events = (raise_on_events,)
        self.raise_on_events = raise_on_events or ()
        self.exc_type = exc_type
        self.seen = []
        self.closed = False

    def __enter__(self, *a):
        sys.addaudithook(self)
        return self

    def __exit__(self, *a):
        self.close()

    def close(self):
        """Stop recording; later calls to the hook return immediately."""
        self.closed = True

    @property
    def seen_events(self):
        """Names of the recorded events, in the order they were seen."""
        return [i[0] for i in self.seen]

    def __call__(self, event, args):
        """Record (event, args) and raise if the event is selected."""
        if self.closed:
            return
        self.seen.append((event, args))
        if event in self.raise_on_events:
            raise self.exc_type("saw event " + event)


# Simple helpers, since we are not in unittest here
def assertEqual(x, y):
    """Raise AssertionError unless x == y."""
    if x != y:
        raise AssertionError(f"{x!r} should equal {y!r}")


def assertIn(el, series):
    """Raise AssertionError unless el is contained in series."""
    if el not in series:
        raise AssertionError(f"{el!r} should be in {series!r}")


def assertNotIn(el, series):
    """Raise AssertionError if el is contained in series."""
    if el in series:
        raise AssertionError(f"{el!r} should not be in {series!r}")


def assertSequenceEqual(x, y):
    """Raise AssertionError unless x and y match in length and item by item."""
    if len(x) != len(y):
        raise AssertionError(f"{x!r} should equal {y!r}")
    if any(ix != iy for ix, iy in zip(x, y)):
        raise AssertionError(f"{x!r} should equal {y!r}")


@contextlib.contextmanager
def assertRaises(ex_type):
    """Context manager that passes only if the body raises exactly ex_type.

    The check is on the exact type (subclasses do not count).  An
    AssertionError raised by the body is always propagated unchanged.
    Failures are raised explicitly rather than with assert statements so
    that the checks still run under -O.
    """
    try:
        yield
    except AssertionError:
        raise
    except BaseException as ex:
        if type(ex) is not ex_type:
            raise AssertionError(f"{ex} should be {ex_type}")
    else:
        raise AssertionError(f"expected {ex_type}")


def test_basic():
    """A hook sees the event name and arguments passed to sys.audit()."""
    with TestHook() as hook:
        sys.audit("test_event", 1, 2, 3)
        assertEqual(hook.seen[0][0], "test_event")
        assertEqual(hook.seen[0][1], (1, 2, 3))


def test_block_add_hook():
    # Raising an exception should prevent a new hook from being added,
    # but will not propagate out.
    with TestHook(raise_on_events="sys.addaudithook") as hook1:
        with TestHook() as hook2:
            sys.audit("test_event")
            assertIn("test_event", hook1.seen_events)
            assertNotIn("test_event", hook2.seen_events)


def test_block_add_hook_baseexception():
    # Raising BaseException will propagate out when adding a hook
    with assertRaises(BaseException):
        with TestHook(
            raise_on_events="sys.addaudithook", exc_type=BaseException
        ) as hook1:
            # Adding this next hook should raise BaseException
            with TestHook() as hook2:
                pass


def test_marshal():
    """Check the events seen for marshal.dumps/dump/loads/load."""
    import marshal
    o = ("a", "b", "c", 1, 2, 3)
    payload = marshal.dumps(o)

    with TestHook() as hook:
        assertEqual(o, marshal.loads(marshal.dumps(o)))

        try:
            with open("test-marshal.bin", "wb") as f:
                marshal.dump(o, f)
            with open("test-marshal.bin", "rb") as f:
                assertEqual(o, marshal.load(f))
        finally:
            os.unlink("test-marshal.bin")

    actual = [(a[0], a[1]) for e, a in hook.seen if e == "marshal.dumps"]
    assertSequenceEqual(actual, [(o, marshal.version)] * 2)

    actual = [a[0] for e, a in hook.seen if e == "marshal.loads"]
    assertSequenceEqual(actual, [payload])

    actual = [e for e, a in hook.seen if e == "marshal.load"]
    assertSequenceEqual(actual, ["marshal.load"])


def test_pickle():
    """A hook raising on "pickle.find_class" blocks pickles that load globals."""
    import pickle

    class PicklePrint:
        def __reduce_ex__(self, p):
            return str, ("Pwned!",)

    payload_1 = pickle.dumps(PicklePrint())
    payload_2 = pickle.dumps(("a", "b", "c", 1, 2, 3))

    # Before we add the hook, ensure our malicious pickle loads
    assertEqual("Pwned!", pickle.loads(payload_1))

    with TestHook(raise_on_events="pickle.find_class") as hook:
        with assertRaises(RuntimeError):
            # With the hook enabled, loading globals is not allowed
            pickle.loads(payload_1)
        # pickles with no globals are okay
        pickle.loads(payload_2)


def test_monkeypatch():
    """Check which class/instance mutations raise "object.__setattr__"."""
    class A:
        pass

    class B:
        pass

    class C(A):
        pass

    a = A()

    with TestHook() as hook:
        # Catch name changes
        C.__name__ = "X"
        # Catch type changes
        C.__bases__ = (B,)
        # Ensure bypassing __setattr__ is still caught
        type.__dict__["__bases__"].__set__(C, (B,))
        # Catch attribute replacement
        C.__init__ = B.__init__
        # Catch attribute addition
        C.new_attr = 123
        # Catch class changes
        a.__class__ = B

    # The loop variable is deliberately not called "a", so it cannot be
    # confused with the instance used in the expected list below.
    actual = [
        (ev_args[0], ev_args[1])
        for e, ev_args in hook.seen
        if e == "object.__setattr__"
    ]
    assertSequenceEqual(
        [(C, "__name__"), (C, "__bases__"), (C, "__bases__"), (a, "__class__")], actual
    )


def test_open():
    """Every way of opening a file raises "open", and the hook can block it.

    sys.argv[2] is used as the name of the file to open.
    """
    # SSLContext.load_dh_params uses _Py_fopen_obj rather than normal open()
    try:
        import ssl

        load_dh_params = ssl.create_default_context().load_dh_params
    except ImportError:
        load_dh_params = None

    # Try a range of "open" functions.
    # All of them should fail
    with TestHook(raise_on_events={"open"}) as hook:
        for fn, *args in [
            (open, sys.argv[2], "r"),
            (open, sys.executable, "rb"),
            (open, 3, "wb"),
            (open, sys.argv[2], "w", -1, None, None, None, False, lambda *a: 1),
            (load_dh_params, sys.argv[2]),
        ]:
            if not fn:
                continue
            with assertRaises(RuntimeError):
                fn(*args)

    actual_mode = [(a[0], a[1]) for e, a in hook.seen if e == "open" and a[1]]
    actual_flag = [(a[0], a[2]) for e, a in hook.seen if e == "open" and not a[1]]
    assertSequenceEqual(
        [
            i
            for i in [
                (sys.argv[2], "r"),
                (sys.executable, "r"),
                (3, "w"),
                (sys.argv[2], "w"),
                (sys.argv[2], "rb") if load_dh_params else None,
            ]
            if i is not None
        ],
        actual_mode,
    )
    assertSequenceEqual([], actual_flag)


def test_cantrace():
    """The hook's __call__ is traced only while hook.__cantrace__ is true."""
    traced = []

    def trace(frame, event, *args):
        if frame.f_code == TestHook.__call__.__code__:
            traced.append(event)

    # sys.settrace() returns None, so the previous trace function has to be
    # read with sys.gettrace() before it is replaced.
    old = sys.gettrace()
    sys.settrace(trace)
    try:
        with TestHook() as hook:
            # No traced call
            eval("1")

            # No traced call
            hook.__cantrace__ = False
            eval("2")

            # One traced call
            hook.__cantrace__ = True
            eval("3")

            # Two traced calls (writing to private member, eval)
            hook.__cantrace__ = 1
            eval("4")

            # One traced call (writing to private member)
            hook.__cantrace__ = 0
    finally:
        sys.settrace(old)

    assertSequenceEqual(["call"] * 4, traced)


def test_mmap():
    """The first recorded event's arguments start with the mmap() arguments."""
    import mmap

    with TestHook() as hook:
        mmap.mmap(-1, 8)
        assertEqual(hook.seen[0][1][:2], (-1, 8))


def test_excepthook():
    """Print the "sys.excepthook" event produced by an uncaught RuntimeError.

    This function always ends by raising RuntimeError("fatal-error").
    """
    def excepthook(exc_type, exc_value, exc_tb):
        if exc_type is not RuntimeError:
            sys.__excepthook__(exc_type, exc_value, exc_tb)

    def hook(event, args):
        if event == "sys.excepthook":
            if not isinstance(args[2], args[1]):
                raise TypeError(f"Expected isinstance({args[2]!r}, " f"{args[1]!r})")
            if args[0] != excepthook:
                raise ValueError(f"Expected {args[0]} == {excepthook}")
            print(event, repr(args[2]))

    sys.addaudithook(hook)
    sys.excepthook = excepthook
    raise RuntimeError("fatal-error")


def test_unraisablehook():
    """Print the "sys.unraisablehook" event for an unraisable exception."""
    from _testcapi import write_unraisable_exc

    def unraisablehook(hookargs):
        pass

    def hook(event, args):
        if event == "sys.unraisablehook":
            if args[0] != unraisablehook:
                raise ValueError(f"Expected {args[0]} == {unraisablehook}")
            print(event, repr(args[1].exc_value), args[1].err_msg)

    sys.addaudithook(hook)
    sys.unraisablehook = unraisablehook
    write_unraisable_exc(RuntimeError("nonfatal-error"), "for audit hook test", None)


def test_winreg():
    """Print every "winreg." event seen while using a registry key."""
    from winreg import OpenKey, EnumKey, CloseKey, HKEY_LOCAL_MACHINE

    def hook(event, args):
        if not event.startswith("winreg."):
            return
        print(event, *args)

    sys.addaudithook(hook)

    k = OpenKey(HKEY_LOCAL_MACHINE, "Software")
    EnumKey(k, 0)
    try:
        EnumKey(k, 10000)
    except OSError:
        pass
    else:
        raise RuntimeError("Expected EnumKey(HKLM, 10000) to fail")

    kv = k.Detach()
    CloseKey(kv)


def test_socket():
    """Print every "socket." event seen during a few socket calls."""
    import socket

    def hook(event, args):
        if event.startswith("socket."):
            print(event, *args)

    sys.addaudithook(hook)

    socket.gethostname()

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Don't care if this fails, we just want the audit message
        sock.bind(('127.0.0.1', 8080))
    except Exception:
        pass
    finally:
        sock.close()


def test_gc():
    """Print every "gc." event seen during a few gc introspection calls."""
    import gc

    def hook(event, args):
        if event.startswith("gc."):
            print(event, *args)

    sys.addaudithook(hook)

    gc.get_objects(generation=1)

    x = object()
    y = [x]

    gc.get_referrers(x)
    gc.get_referents(y)


def test_http_client():
    """Print every "http.client." event seen during a GET request."""
    import http.client

    def hook(event, args):
        if event.startswith("http.client."):
            # The first event argument is skipped in the output.
            print(event, *args[1:])

    sys.addaudithook(hook)

    conn = http.client.HTTPConnection('www.python.org')
    try:
        conn.request('GET', '/')
    except OSError:
        print('http.client.send', '[cannot send]')
    finally:
        conn.close()


def test_sqlite3():
    """Print every "sqlite3." event seen while connecting and loading."""
    import sqlite3

    # NOTE: the hook takes *args, so the argument tuple passed by the audit
    # machinery is printed as a single tuple.  Left as is because the
    # printed output is consumed by the caller of this script.
    def hook(event, *args):
        if event.startswith("sqlite3."):
            print(event, *args)

    sys.addaudithook(hook)
    cx1 = sqlite3.connect(":memory:")
    cx2 = sqlite3.Connection(":memory:")

    try:
        # Configured without --enable-loadable-sqlite-extensions
        if hasattr(sqlite3.Connection, "enable_load_extension"):
            cx1.enable_load_extension(False)
            try:
                cx1.load_extension("test")
            except sqlite3.OperationalError:
                pass
            else:
                raise RuntimeError("Expected sqlite3.load_extension to fail")
    finally:
        # Close both connections explicitly instead of leaving them to be
        # released at interpreter shutdown.
        cx1.close()
        cx2.close()


if __name__ == "__main__":
    from test.support import suppress_msvcrt_asserts

    suppress_msvcrt_asserts()

    test = sys.argv[1]
    globals()[test]()