"""Windows-only tests for ctypes foreign functions that call COM methods.

ForeignFunctionsThatWillCallComMethodsTests builds COM method wrappers from
``ctypes.WINFUNCTYPE`` prototypes, with and without paramflags and an
interface IID. It calls them on a ShellLink object obtained through
``CoCreateInstance``. CopyComPointerTests exercises
``_ctypes.CopyComPointer``.

Many assertions check the exact value returned by ``AddRef``/``Release``, so
the order of the calls inside each test is significant.
"""

import ctypes
import gc
import sys
import unittest
from ctypes import POINTER, byref, c_void_p
from ctypes.wintypes import BYTE, DWORD, WORD

if sys.platform != "win32":
    raise unittest.SkipTest("Windows-specific test")


from _ctypes import COMError, CopyComPointer
from ctypes import HRESULT


# dwCoInit flag passed to CoInitializeEx in setUp.
COINIT_APARTMENTTHREADED = 0x2
# dwClsContext passed to CoCreateInstance. 5 == 0x1 | 0x4, presumably
# CLSCTX_INPROC_SERVER | CLSCTX_LOCAL_SERVER.
# NOTE(review): current Windows SDK headers define CLSCTX_SERVER as 0x15
# (it also includes the remote-server bit); confirm 5 is intended.
CLSCTX_SERVER = 5
# HRESULT success code.
S_OK = 0
# ctypes paramflags direction value for an output parameter. ctypes allocates
# such an argument itself and returns it as the result of the call.
OUT = 2
# Win32 BOOL "true"; compared against the result of IsEqualGUID.
TRUE = 1
# E_NOINTERFACE (0x80004002) expressed as a signed 32-bit integer.
E_NOINTERFACE = -2147467262


class GUID(ctypes.Structure):
    """ctypes mirror of the Win32 GUID structure (also used for IIDs/CLSIDs)."""

    # https://learn.microsoft.com/en-us/windows/win32/api/guiddef/ns-guiddef-guid
    _fields_ = [
        ("Data1", DWORD),
        ("Data2", WORD),
        ("Data3", WORD),
        ("Data4", BYTE * 8),
    ]


def create_proto_com_method(name, index, restype, *argtypes):
    """Return a factory for wrappers of the COM method at vtable slot *index*.

    *restype* and *argtypes* describe the method's C signature, excluding the
    interface pointer itself.

    The returned ``make_method(*args)`` forwards its optional arguments to the
    ``WINFUNCTYPE`` prototype together with *index* and *name*. In this module
    those arguments are paramflags, optionally followed by an interface IID.
    ``make_method`` returns a plain function meant to be stored as a class
    attribute. When that function is called as a method, the instance (the
    interface pointer) is passed as the first argument of the foreign
    function.
    """
    proto = ctypes.WINFUNCTYPE(restype, *argtypes)

    def make_method(*args):
        foreign_func = proto(index, name, *args)

        def call(self, *args, **kwargs):
            return foreign_func(self, *args, **kwargs)

        return call

    return make_method


def create_guid(name):
    """Return a GUID parsed by CLSIDFromString from the string *name*.

    *name* is a braced GUID string such as
    ``"{00000000-0000-0000-C000-000000000046}"``. This function uses the
    module-level ``ole32``. That name is bound further down in this module,
    but before the first call.
    """
    guid = GUID()
    # https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-clsidfromstring
    ole32.CLSIDFromString(name, byref(guid))
    return guid


def is_equal_guid(guid1, guid2):
    """Return the IsEqualGUID result for two GUID structures.

    The result is nonzero when the GUIDs are equal; the tests compare it
    with ``TRUE``.
    """
    # https://learn.microsoft.com/en-us/windows/win32/api/objbase/nf-objbase-isequalguid
    return ole32.IsEqualGUID(byref(guid1), byref(guid2))


# Functions loaded through ctypes.oledll are treated as returning an HRESULT,
# and a failure code is raised as OSError instead of being returned.
ole32 = ctypes.oledll.ole32

# Interface IDs and the ShellLink class ID used by the tests.
IID_IUnknown = create_guid("{00000000-0000-0000-C000-000000000046}")
IID_IStream = create_guid("{0000000C-0000-0000-C000-000000000046}")
IID_IPersist = create_guid("{0000010C-0000-0000-C000-000000000046}")
CLSID_ShellLink = create_guid("{00021401-0000-0000-C000-000000000046}")

# Method factories; the second argument is the method's vtable slot.
# https://learn.microsoft.com/en-us/windows/win32/api/unknwn/nf-unknwn-iunknown-queryinterface(refiid_void)
proto_query_interface = create_proto_com_method(
    "QueryInterface", 0, HRESULT, POINTER(GUID), POINTER(c_void_p)
)
# https://learn.microsoft.com/en-us/windows/win32/api/unknwn/nf-unknwn-iunknown-addref
proto_add_ref = create_proto_com_method("AddRef", 1, ctypes.c_long)
# https://learn.microsoft.com/en-us/windows/win32/api/unknwn/nf-unknwn-iunknown-release
proto_release = create_proto_com_method("Release", 2, ctypes.c_long)
# https://learn.microsoft.com/en-us/windows/win32/api/objidl/nf-objidl-ipersist-getclassid
proto_get_class_id = create_proto_com_method(
    "GetClassID", 3, HRESULT, POINTER(GUID)
)


def create_shelllink_persist(typ):
    """Create a ShellLink object and return its IPersist interface pointer.

    *typ* is the ctypes pointer class (a ``c_void_p`` subclass in these
    tests). A new instance of it receives the pointer written by
    CoCreateInstance. The tests expect the returned pointer to start with a
    reference count of 1.
    """
    ppst = typ()
    # https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-cocreateinstance
    ole32.CoCreateInstance(
        byref(CLSID_ShellLink),
        None,
        CLSCTX_SERVER,
        byref(IID_IPersist),
        byref(ppst),
    )
    return ppst


class ForeignFunctionsThatWillCallComMethodsTests(unittest.TestCase):
    """Call COM methods through prototypes built with different combinations
    of paramflags and interface IID."""

    def setUp(self):
        # https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-coinitializeex
        ole32.CoInitializeEx(None, COINIT_APARTMENTTHREADED)

    def tearDown(self):
        # https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-couninitialize
        ole32.CoUninitialize()
        gc.collect()

    def test_without_paramflags_and_iid(self):
        # Without paramflags, every argument (including the GUID out-buffer)
        # is passed explicitly and the HRESULT is returned. Without an IID, a
        # failing HRESULT raises OSError.
        class IUnknown(c_void_p):
            QueryInterface = proto_query_interface()
            AddRef = proto_add_ref()
            Release = proto_release()

        class IPersist(IUnknown):
            GetClassID = proto_get_class_id()

        ppst = create_shelllink_persist(IPersist)  # refcount: 1

        clsid = GUID()
        hr_getclsid = ppst.GetClassID(byref(clsid))
        self.assertEqual(S_OK, hr_getclsid)
        self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))

        self.assertEqual(2, ppst.AddRef())
        self.assertEqual(3, ppst.AddRef())

        # The successful QueryInterface gives `punk` its own reference (4);
        # releasing it brings the count back to 3.
        punk = IUnknown()
        hr_qi = ppst.QueryInterface(IID_IUnknown, punk)
        self.assertEqual(S_OK, hr_qi)
        self.assertEqual(3, punk.Release())

        # Asking for IStream fails with E_NOINTERFACE, reported through
        # OSError.winerror.
        with self.assertRaises(OSError) as e:
            punk.QueryInterface(IID_IStream, IUnknown())
        self.assertEqual(E_NOINTERFACE, e.exception.winerror)

        # The failed QueryInterface added no reference: 3 -> 2 -> 1 -> 0.
        self.assertEqual(2, ppst.Release())
        self.assertEqual(1, ppst.Release())
        self.assertEqual(0, ppst.Release())

    def test_with_paramflags_and_without_iid(self):
        # GetClassID declares its argument as OUT, so ctypes allocates the
        # GUID and returns it. QueryInterface is built with paramflags None,
        # so it still takes both arguments and returns the HRESULT. Without
        # an IID, a failing HRESULT raises OSError.
        class IUnknown(c_void_p):
            QueryInterface = proto_query_interface(None)
            AddRef = proto_add_ref()
            Release = proto_release()

        class IPersist(IUnknown):
            GetClassID = proto_get_class_id(((OUT, "pClassID"),))

        ppst = create_shelllink_persist(IPersist)  # refcount: 1

        clsid = ppst.GetClassID()
        self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))

        # QueryInterface adds a reference for `punk` (2); release it (1).
        punk = IUnknown()
        hr_qi = ppst.QueryInterface(IID_IUnknown, punk)
        self.assertEqual(S_OK, hr_qi)
        self.assertEqual(1, punk.Release())

        with self.assertRaises(OSError) as e:
            ppst.QueryInterface(IID_IStream, IUnknown())
        self.assertEqual(E_NOINTERFACE, e.exception.winerror)

        self.assertEqual(0, ppst.Release())

    def test_with_paramflags_and_iid(self):
        # Like the paramflags-only test, but the prototypes also carry an
        # interface IID. A failing HRESULT then raises _ctypes.COMError,
        # whose `hresult` attribute holds the code, instead of OSError.
        class IUnknown(c_void_p):
            QueryInterface = proto_query_interface(None, IID_IUnknown)
            AddRef = proto_add_ref()
            Release = proto_release()

        class IPersist(IUnknown):
            GetClassID = proto_get_class_id(((OUT, "pClassID"),), IID_IPersist)

        ppst = create_shelllink_persist(IPersist)  # refcount: 1

        clsid = ppst.GetClassID()
        self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))

        # QueryInterface adds a reference for `punk` (2); release it (1).
        punk = IUnknown()
        hr_qi = ppst.QueryInterface(IID_IUnknown, punk)
        self.assertEqual(S_OK, hr_qi)
        self.assertEqual(1, punk.Release())

        with self.assertRaises(COMError) as e:
            ppst.QueryInterface(IID_IStream, IUnknown())
        self.assertEqual(E_NOINTERFACE, e.exception.hresult)

        self.assertEqual(0, ppst.Release())


class CopyComPointerTests(unittest.TestCase):
    """Tests for ``_ctypes.CopyComPointer(src, byref(dst))``.

    As the assertions below show, a non-NULL *src* is AddRef'd once and
    stored in *dst*. The pointer *dst* held before is overwritten without
    being released.
    """

    def setUp(self):
        ole32.CoInitializeEx(None, COINIT_APARTMENTTHREADED)

        # Fresh interface classes for every test, built with paramflags and
        # interface IIDs.
        class IUnknown(c_void_p):
            QueryInterface = proto_query_interface(None, IID_IUnknown)
            AddRef = proto_add_ref()
            Release = proto_release()

        class IPersist(IUnknown):
            GetClassID = proto_get_class_id(((OUT, "pClassID"),), IID_IPersist)

        self.IUnknown = IUnknown
        self.IPersist = IPersist

    def tearDown(self):
        ole32.CoUninitialize()
        gc.collect()

    def test_both_are_null(self):
        # Copying a NULL pointer into a NULL pointer succeeds and changes
        # nothing.
        src = self.IPersist()
        dst = self.IPersist()

        hr = CopyComPointer(src, byref(dst))

        self.assertEqual(S_OK, hr)

        self.assertIsNone(src.value)
        self.assertIsNone(dst.value)

    def test_src_is_nonnull_and_dest_is_null(self):
        # The reference count of the COM pointer created by `CoCreateInstance`
        # is initially 1.
        src = create_shelllink_persist(self.IPersist)
        dst = self.IPersist()

        # `CopyComPointer` calls `AddRef` explicitly in the C implementation.
        # The refcount of `src` is incremented from 1 to 2 here.
        hr = CopyComPointer(src, byref(dst))

        self.assertEqual(S_OK, hr)
        self.assertEqual(src.value, dst.value)

        # This indicates that the refcount was 2 before the `Release` call.
        self.assertEqual(1, src.Release())

        clsid = dst.GetClassID()
        self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))

        self.assertEqual(0, dst.Release())

    def test_src_is_null_and_dest_is_nonnull(self):
        src = self.IPersist()  # NULL
        dst_orig = create_shelllink_persist(self.IPersist)  # refcount: 1
        dst = self.IPersist()
        # Give `dst` its own reference (2), then drop the creation
        # reference (1).
        CopyComPointer(dst_orig, byref(dst))
        self.assertEqual(1, dst_orig.Release())

        clsid = dst.GetClassID()
        self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))

        # This does NOT affect the refcount of `dst_orig`: the pointer held
        # by `dst` is overwritten with NULL without a `Release` call.
        hr = CopyComPointer(src, byref(dst))

        self.assertEqual(S_OK, hr)
        self.assertIsNone(dst.value)

        with self.assertRaises(ValueError):
            dst.GetClassID()  # NULL COM pointer access

        # This indicates that the refcount was 1 before the `Release` call.
        self.assertEqual(0, dst_orig.Release())

    def test_both_are_nonnull(self):
        src = create_shelllink_persist(self.IPersist)  # refcount: 1
        dst_orig = create_shelllink_persist(self.IPersist)  # refcount: 1
        dst = self.IPersist()
        # Give `dst` its own reference to `dst_orig` (2), then drop the
        # creation reference (1).
        CopyComPointer(dst_orig, byref(dst))
        self.assertEqual(1, dst_orig.Release())

        self.assertEqual(dst.value, dst_orig.value)
        self.assertNotEqual(src.value, dst.value)

        # `src` is AddRef'd (2) and stored in `dst`. The `dst_orig` pointer
        # that `dst` held is overwritten without a `Release`, so its refcount
        # stays at 1.
        hr = CopyComPointer(src, byref(dst))

        self.assertEqual(S_OK, hr)
        self.assertEqual(src.value, dst.value)
        self.assertNotEqual(dst.value, dst_orig.value)

        self.assertEqual(1, src.Release())

        clsid = dst.GetClassID()
        self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))

        # `dst` now releases the last reference to the `src` object, and
        # `dst_orig` still owns exactly one reference of its own.
        self.assertEqual(0, dst.Release())
        self.assertEqual(0, dst_orig.Release())


if __name__ == '__main__':
    unittest.main()