"""This script contains the actual auditing tests.

It should not be imported directly, but should be run by the test_audit
module with arguments identifying each test.

"""

import contextlib
import sys

# Only the helpers are exported. The test_* functions are entry points that
# the __main__ block invokes by name; keeping them out of a star-import stops
# them from being pulled into another namespace and run outside this harness.
__all__ = [
    "TestHook",
    "assertEqual",
    "assertIn",
    "assertNotIn",
    "assertSequenceEqual",
    "assertRaises",
]


class TestHook:
    """Used in standard hook tests to collect any logged events.

    Should be used in a with block to ensure that it has no impact
    after the test completes.

    raise_on_events: a single event name, or a collection of event names,
        for which the hook raises *exc_type* after recording the event.
    exc_type: the exception type raised for those events.
    """

    def __init__(self, raise_on_events=None, exc_type=RuntimeError):
        events = raise_on_events or ()
        # A bare string must be treated as ONE event name: `event in "x.y"`
        # is a substring test and would also match events like "x".
        if isinstance(events, str):
            events = (events,)
        self.raise_on_events = events
        self.exc_type = exc_type
        self.seen = []  # (event, args) pairs, in the order they arrived
        self.closed = False

    def __enter__(self, *a):
        sys.addaudithook(self)
        return self

    def __exit__(self, *a):
        # The hook stays registered with sys; closing it only makes
        # __call__ ignore everything from now on.
        self.close()

    def close(self):
        self.closed = True

    @property
    def seen_events(self):
        """Names of all recorded events, in arrival order."""
        return [i[0] for i in self.seen]

    def __call__(self, event, args):
        if self.closed:
            return
        self.seen.append((event, args))
        if event in self.raise_on_events:
            raise self.exc_type("saw event " + event)


# Simple helpers, since we are not in unittest here
def assertEqual(x, y):
    """Raise AssertionError unless x == y."""
    if x != y:
        raise AssertionError(f"{x!r} should equal {y!r}")


def assertIn(el, series):
    """Raise AssertionError unless el is a member of series."""
    if el not in series:
        raise AssertionError(f"{el!r} should be in {series!r}")


def assertNotIn(el, series):
    """Raise AssertionError if el is a member of series."""
    if el in series:
        raise AssertionError(f"{el!r} should not be in {series!r}")


def assertSequenceEqual(x, y):
    """Raise AssertionError unless x and y have equal length and items."""
    if len(x) != len(y):
        raise AssertionError(f"{x!r} should equal {y!r}")
    if any(ix != iy for ix, iy in zip(x, y)):
        raise AssertionError(f"{x!r} should equal {y!r}")


@contextlib.contextmanager
def assertRaises(ex_type):
    """Require the with-body to raise exactly *ex_type* (not a subclass).

    A matching exception is swallowed. An AssertionError raised by the body
    always propagates, so failed inner checks are never hidden. Explicit
    raises are used instead of `assert` so the check still works under -O.
    """
    try:
        yield
    except AssertionError:
        raise
    except BaseException as ex:
        if type(ex) is not ex_type:
            raise AssertionError(f"{ex} should be {ex_type}") from ex
    else:
        raise AssertionError(f"expected {ex_type}")


def test_basic():
    """A hook receives the event name and argument tuple from sys.audit."""
    with TestHook() as hook:
        sys.audit("test_event", 1, 2, 3)
        assertEqual(hook.seen[0][0], "test_event")
        assertEqual(hook.seen[0][1], (1, 2, 3))


def test_block_add_hook():
    # Raising an exception should prevent a new hook from being added,
    # but will not propagate out.
    with TestHook(raise_on_events="sys.addaudithook") as hook1:
        with TestHook() as hook2:
            sys.audit("test_event")
            assertIn("test_event", hook1.seen_events)
            assertNotIn("test_event", hook2.seen_events)


def test_block_add_hook_baseexception():
    # Raising BaseException will propagate out when adding a hook
    with assertRaises(BaseException):
        with TestHook(
            raise_on_events="sys.addaudithook", exc_type=BaseException
        ):
            # Adding this next hook should raise BaseException
            with TestHook():
                pass


def test_pickle():
    """A hook raising on pickle.find_class blocks pickles that load globals."""
    import pickle

    class PicklePrint:
        def __reduce_ex__(self, p):
            return str, ("Pwned!",)

    payload_1 = pickle.dumps(PicklePrint())
    payload_2 = pickle.dumps(("a", "b", "c", 1, 2, 3))

    # Before we add the hook, ensure our malicious pickle loads
    assertEqual("Pwned!", pickle.loads(payload_1))

    with TestHook(raise_on_events="pickle.find_class") as hook:
        with assertRaises(RuntimeError):
            # With the hook enabled, loading globals is not allowed
            pickle.loads(payload_1)
        # pickles with no globals are okay
        pickle.loads(payload_2)


def test_monkeypatch():
    """Check which class/instance mutations raise object.__setattr__."""

    class A:
        pass

    class B:
        pass

    class C(A):
        pass

    a = A()

    with TestHook() as hook:
        # Catch name changes
        C.__name__ = "X"
        # Catch type changes
        C.__bases__ = (B,)
        # Ensure bypassing __setattr__ is still caught
        type.__dict__["__bases__"].__set__(C, (B,))
        # Catch attribute replacement
        C.__init__ = B.__init__
        # Catch attribute addition
        C.new_attr = 123
        # Catch class changes
        a.__class__ = B

    # Distinct comprehension names so `a` below clearly means the instance.
    actual = [
        (args[0], args[1])
        for event, args in hook.seen
        if event == "object.__setattr__"
    ]
    assertSequenceEqual(
        [(C, "__name__"), (C, "__bases__"), (C, "__bases__"), (a, "__class__")], actual
    )


def test_open():
    """Every "open"-style entry point must report an "open" audit event."""
    # SSLContext.load_dh_params uses _Py_fopen_obj rather than normal open()
    try:
        import ssl

        load_dh_params = ssl.create_default_context().load_dh_params
    except ImportError:
        load_dh_params = None

    # Try a range of "open" functions.
    # All of them should fail
    with TestHook(raise_on_events={"open"}) as hook:
        for fn, *args in [
            (open, sys.argv[2], "r"),
            (open, sys.executable, "rb"),
            (open, 3, "wb"),
            (open, sys.argv[2], "w", -1, None, None, None, False, lambda *a: 1),
            (load_dh_params, sys.argv[2]),
        ]:
            if not fn:
                continue
            with assertRaises(RuntimeError):
                fn(*args)

    # Split "open" events by whether their second argument (mode) is set.
    actual_mode = [(a[0], a[1]) for e, a in hook.seen if e == "open" and a[1]]
    actual_flag = [(a[0], a[2]) for e, a in hook.seen if e == "open" and not a[1]]
    assertSequenceEqual(
        [
            i
            for i in [
                (sys.argv[2], "r"),
                (sys.executable, "r"),
                (3, "w"),
                (sys.argv[2], "w"),
                (sys.argv[2], "rb") if load_dh_params else None,
            ]
            if i is not None
        ],
        actual_mode,
    )
    assertSequenceEqual([], actual_flag)


def test_cantrace():
    """__cantrace__ on a hook controls whether its calls are traced."""
    traced = []

    def trace(frame, event, *args):
        if frame.f_code == TestHook.__call__.__code__:
            traced.append(event)

    old = sys.settrace(trace)
    try:
        with TestHook() as hook:
            # No traced call
            eval("1")

            # No traced call
            hook.__cantrace__ = False
            eval("2")

            # One traced call
            hook.__cantrace__ = True
            eval("3")

            # Two traced calls (writing to private member, eval)
            hook.__cantrace__ = 1
            eval("4")

            # One traced call (writing to private member)
            hook.__cantrace__ = 0
    finally:
        sys.settrace(old)

    assertSequenceEqual(["call"] * 4, traced)


def test_mmap():
    """Creating an anonymous mmap reports (fileno, length) as its first args."""
    import mmap

    with TestHook() as hook:
        # Close the mapping promptly rather than leaving it to the GC.
        with mmap.mmap(-1, 8):
            pass
        assertEqual(hook.seen[0][1][:2], (-1, 8))


def test_excepthook():
    """Print the sys.excepthook audit event for an uncaught RuntimeError."""

    def excepthook(exc_type, exc_value, exc_tb):
        if exc_type is not RuntimeError:
            sys.__excepthook__(exc_type, exc_value, exc_tb)

    def hook(event, args):
        if event == "sys.excepthook":
            if not isinstance(args[2], args[1]):
                raise TypeError(f"Expected isinstance({args[2]!r}, {args[1]!r})")
            if args[0] != excepthook:
                raise ValueError(f"Expected {args[0]} == {excepthook}")
            print(event, repr(args[2]))

    sys.addaudithook(hook)
    sys.excepthook = excepthook
    raise RuntimeError("fatal-error")


def test_unraisablehook():
    """Print the sys.unraisablehook audit event for an unraisable error."""
    from _testcapi import write_unraisable_exc

    def unraisablehook(hookargs):
        pass

    def hook(event, args):
        if event == "sys.unraisablehook":
            if args[0] != unraisablehook:
                raise ValueError(f"Expected {args[0]} == {unraisablehook}")
            print(event, repr(args[1].exc_value), args[1].err_msg)

    sys.addaudithook(hook)
    sys.unraisablehook = unraisablehook
    write_unraisable_exc(RuntimeError("nonfatal-error"), "for audit hook test", None)


def test_winreg():
    """Print every winreg.* audit event raised by a few registry calls."""
    from winreg import OpenKey, EnumKey, CloseKey, HKEY_LOCAL_MACHINE

    def hook(event, args):
        if not event.startswith("winreg."):
            return
        print(event, *args)

    sys.addaudithook(hook)

    k = OpenKey(HKEY_LOCAL_MACHINE, "Software")
    EnumKey(k, 0)
    try:
        EnumKey(k, 10000)
    except OSError:
        pass
    else:
        raise RuntimeError("Expected EnumKey(HKLM, 10000) to fail")

    kv = k.Detach()
    CloseKey(kv)


def test_socket():
    """Print every socket.* audit event raised by a few socket calls."""
    import socket

    def hook(event, args):
        if event.startswith("socket."):
            print(event, *args)

    sys.addaudithook(hook)

    socket.gethostname()

    # Socket creation is not guarded: only the bind below may fail.
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Don't care if this fails, we just want the audit message
        sock.bind(('127.0.0.1', 8080))
    except Exception:
        pass
    finally:
        sock.close()


if __name__ == "__main__":
    from test.support import suppress_msvcrt_asserts

    suppress_msvcrt_asserts()

    test = sys.argv[1]
    globals()[test]()