• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import ctypes
2import gc
3import sys
4import unittest
5from ctypes import POINTER, byref, c_void_p
6from ctypes.wintypes import BYTE, DWORD, WORD
7
8if sys.platform != "win32":
9    raise unittest.SkipTest("Windows-specific test")
10
11
12from _ctypes import COMError, CopyComPointer
13from ctypes import HRESULT
14
15
# Win32 / COM constants used by the tests below.
COINIT_APARTMENTTHREADED = 0x2  # CoInitializeEx: single-threaded apartment
# CLSCTX_INPROC_SERVER (0x1) | CLSCTX_LOCAL_SERVER (0x4).
# NOTE: the SDK's CLSCTX_SERVER macro additionally ORs in
# CLSCTX_REMOTE_SERVER (0x10); this module deliberately uses 5.
CLSCTX_SERVER = 0x1 | 0x4
S_OK = 0
OUT = 2  # ctypes paramflags value marking an output parameter
TRUE = 1
# 0x80004002 reinterpreted as a signed 32-bit HRESULT.
E_NOINTERFACE = 0x80004002 - (1 << 32)
22
23
class GUID(ctypes.Structure):
    """ctypes mirror of the Win32 ``GUID`` structure.

    Field order and types follow the ``guiddef.h`` definition:
    https://learn.microsoft.com/en-us/windows/win32/api/guiddef/ns-guiddef-guid
    """

    # Three leading integers (32-, 16- and 16-bit on Win32) followed by an
    # eight-byte tail.
    _fields_ = list(
        zip(
            ("Data1", "Data2", "Data3", "Data4"),
            (DWORD, WORD, WORD, BYTE * 8),
        )
    )
32
33
def create_proto_com_method(name, index, restype, *argtypes):
    """Return a factory producing Python methods that call a COM vtable slot.

    ``restype``/``argtypes`` describe the native signature (excluding the
    implicit ``this`` pointer) and are turned into a ``WINFUNCTYPE``
    prototype once.  Each call of the returned factory instantiates that
    prototype as a COM method bound to vtable slot ``index`` named ``name``;
    any factory arguments (e.g. paramflags and an iid) are forwarded to the
    prototype constructor.  The resulting function passes ``self`` as the
    interface pointer, so it can be placed directly in a ``c_void_p``
    subclass body.
    """
    prototype = ctypes.WINFUNCTYPE(restype, *argtypes)

    def make_method(*prototype_extras):
        com_method = prototype(index, name, *prototype_extras)

        def call(self, *call_args, **call_kwargs):
            # `self` is the interface pointer the vtable slot is looked up on.
            return com_method(self, *call_args, **call_kwargs)

        return call

    return make_method
46
47
def create_guid(name):
    """Return a ``GUID`` parsed from its brace-delimited string form.

    ``name`` is handed straight to ole32's ``CLSIDFromString``; since
    ``ole32`` is an ``oledll`` handle, a failing HRESULT surfaces as an
    exception rather than being returned.
    https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-clsidfromstring
    """
    parsed = GUID()
    ole32.CLSIDFromString(name, byref(parsed))
    return parsed
53
54
def is_equal_guid(guid1, guid2):
    """Compare two ``GUID`` structures with ole32's ``IsEqualGUID``.

    The raw integer result is returned; the tests compare it with ``TRUE``.
    https://learn.microsoft.com/en-us/windows/win32/api/objbase/nf-objbase-isequalguid
    """
    lhs, rhs = byref(guid1), byref(guid2)
    return ole32.IsEqualGUID(lhs, rhs)
58
59
# ole32 is loaded through `oledll`, so its exports are treated as returning
# HRESULTs and failing results raise instead of being returned.
ole32 = ctypes.oledll.ole32

# Interface and class identifiers exercised by the tests.
IID_IUnknown = create_guid("{00000000-0000-0000-C000-000000000046}")
IID_IStream = create_guid("{0000000C-0000-0000-C000-000000000046}")
IID_IPersist = create_guid("{0000010C-0000-0000-C000-000000000046}")
CLSID_ShellLink = create_guid("{00021401-0000-0000-C000-000000000046}")

# Method factories, one per vtable slot: IUnknown occupies slots 0-2 and
# IPersist::GetClassID follows in slot 3.

# Slot 0:
# https://learn.microsoft.com/en-us/windows/win32/api/unknwn/nf-unknwn-iunknown-queryinterface(refiid_void)
proto_query_interface = create_proto_com_method(
    "QueryInterface",
    0,
    HRESULT,
    POINTER(GUID),
    POINTER(c_void_p),
)

# Slots 1 and 2 share the same signature (no arguments, LONG result):
# https://learn.microsoft.com/en-us/windows/win32/api/unknwn/nf-unknwn-iunknown-addref
# https://learn.microsoft.com/en-us/windows/win32/api/unknwn/nf-unknwn-iunknown-release
proto_add_ref, proto_release = (
    create_proto_com_method(method_name, slot, ctypes.c_long)
    for method_name, slot in (("AddRef", 1), ("Release", 2))
)

# Slot 3:
# https://learn.microsoft.com/en-us/windows/win32/api/objidl/nf-objidl-ipersist-getclassid
proto_get_class_id = create_proto_com_method(
    "GetClassID",
    3,
    HRESULT,
    POINTER(GUID),
)
79
80
def create_shelllink_persist(typ):
    """Create a ShellLink COM object and return its ``IPersist`` pointer.

    ``typ`` is the ctypes pointer class to fill in (the tests pass
    ``c_void_p`` subclasses).  The returned pointer holds the reference
    handed out by ``CoCreateInstance``; the caller must release it.
    https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-cocreateinstance
    """
    out_ptr = typ()
    ole32.CoCreateInstance(
        byref(CLSID_ShellLink),  # rclsid
        None,                    # pUnkOuter: no aggregation
        CLSCTX_SERVER,           # dwClsContext
        byref(IID_IPersist),     # riid
        byref(out_ptr),          # ppv: receives the interface pointer
    )
    return out_ptr
92
93
class ForeignFunctionsThatWillCallComMethodsTests(unittest.TestCase):
    """Call COM methods through ctypes prototypes bound to vtable slots.

    Every test builds its own ``IUnknown``/``IPersist`` pointer classes from
    the ``proto_*`` factories and drives them against a live ShellLink
    object.  The variants differ in the extra arguments given to the
    factories:

    * no paramflags, no iid -- arguments are passed explicitly, the HRESULT
      is returned, and a failing HRESULT raises ``OSError``;
    * paramflags, no iid -- ``GetClassID`` returns its ``OUT`` parameter;
      failures still raise ``OSError``;
    * paramflags and iid -- failures raise ``COMError`` instead.
    """

    def setUp(self):
        # Each test runs inside a single-threaded COM apartment.
        # https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-coinitializeex
        ole32.CoInitializeEx(None, COINIT_APARTMENTTHREADED)

    def tearDown(self):
        # https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-couninitialize
        ole32.CoUninitialize()
        gc.collect()

    def test_without_paramflags_and_iid(self):
        class IUnknown(c_void_p):
            QueryInterface = proto_query_interface()
            AddRef = proto_add_ref()
            Release = proto_release()

        class IPersist(IUnknown):
            GetClassID = proto_get_class_id()

        # Refcount starts at 1: the reference handed out by CoCreateInstance.
        ppst = create_shelllink_persist(IPersist)

        # Without paramflags the out-parameter is passed explicitly and the
        # HRESULT comes back as the return value.
        clsid = GUID()
        hr_getclsid = ppst.GetClassID(byref(clsid))
        self.assertEqual(S_OK, hr_getclsid)
        self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))

        self.assertEqual(2, ppst.AddRef())
        self.assertEqual(3, ppst.AddRef())

        # A successful QueryInterface hands out a new reference: 3 -> 4.
        punk = IUnknown()
        hr_qi = ppst.QueryInterface(IID_IUnknown, punk)
        self.assertEqual(S_OK, hr_qi)

        # Probe the failure path through `punk` while it still owns its
        # reference.  Calling through an interface pointer after its own
        # Release() is a use-after-release under COM rules, even when other
        # references happen to keep the object alive.  A failed
        # QueryInterface hands out no reference, so the count stays at 4.
        with self.assertRaises(OSError) as e:
            punk.QueryInterface(IID_IStream, IUnknown())
        self.assertEqual(E_NOINTERFACE, e.exception.winerror)

        self.assertEqual(3, punk.Release())

        self.assertEqual(2, ppst.Release())
        self.assertEqual(1, ppst.Release())
        self.assertEqual(0, ppst.Release())

    def test_with_paramflags_and_without_iid(self):
        class IUnknown(c_void_p):
            # `None` paramflags: both QueryInterface arguments stay explicit.
            QueryInterface = proto_query_interface(None)
            AddRef = proto_add_ref()
            Release = proto_release()

        class IPersist(IUnknown):
            # The OUT paramflag makes GetClassID return the GUID directly.
            GetClassID = proto_get_class_id(((OUT, "pClassID"),))

        ppst = create_shelllink_persist(IPersist)

        clsid = ppst.GetClassID()
        self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))

        punk = IUnknown()
        hr_qi = ppst.QueryInterface(IID_IUnknown, punk)
        self.assertEqual(S_OK, hr_qi)
        self.assertEqual(1, punk.Release())

        # Without an iid, a failing HRESULT is reported as OSError.
        with self.assertRaises(OSError) as e:
            ppst.QueryInterface(IID_IStream, IUnknown())
        self.assertEqual(E_NOINTERFACE, e.exception.winerror)

        self.assertEqual(0, ppst.Release())

    def test_with_paramflags_and_iid(self):
        class IUnknown(c_void_p):
            QueryInterface = proto_query_interface(None, IID_IUnknown)
            AddRef = proto_add_ref()
            Release = proto_release()

        class IPersist(IUnknown):
            GetClassID = proto_get_class_id(((OUT, "pClassID"),), IID_IPersist)

        ppst = create_shelllink_persist(IPersist)

        clsid = ppst.GetClassID()
        self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))

        punk = IUnknown()
        hr_qi = ppst.QueryInterface(IID_IUnknown, punk)
        self.assertEqual(S_OK, hr_qi)
        self.assertEqual(1, punk.Release())

        # With an iid, a failing HRESULT is reported as COMError instead.
        with self.assertRaises(COMError) as e:
            ppst.QueryInterface(IID_IStream, IUnknown())
        self.assertEqual(E_NOINTERFACE, e.exception.hresult)

        self.assertEqual(0, ppst.Release())
185
186
class CopyComPointerTests(unittest.TestCase):
    """Pin down the reference-count semantics of ``_ctypes.CopyComPointer``.

    ``CopyComPointer(src, byref(dst))`` stores ``src`` into ``dst`` and
    AddRefs a non-NULL ``src``.  Whatever ``dst`` pointed to before is left
    untouched and is NOT released.
    """

    def setUp(self):
        ole32.CoInitializeEx(None, COINIT_APARTMENTTHREADED)

        # Fresh pointer classes per test.  The OUT paramflag makes
        # GetClassID return its GUID, and the iids route failing HRESULTs
        # to COMError.
        class IUnknown(c_void_p):
            QueryInterface = proto_query_interface(None, IID_IUnknown)
            AddRef = proto_add_ref()
            Release = proto_release()

        class IPersist(IUnknown):
            GetClassID = proto_get_class_id(((OUT, "pClassID"),), IID_IPersist)

        self.IUnknown, self.IPersist = IUnknown, IPersist

    def tearDown(self):
        ole32.CoUninitialize()
        gc.collect()

    def _assert_is_shelllink(self, ptr):
        """Assert that ``ptr.GetClassID()`` reports ``CLSID_ShellLink``."""
        self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, ptr.GetClassID()))

    def test_both_are_null(self):
        source, target = self.IPersist(), self.IPersist()

        self.assertEqual(S_OK, CopyComPointer(source, byref(target)))
        self.assertIsNone(source.value)
        self.assertIsNone(target.value)

    def test_src_is_nonnull_and_dest_is_null(self):
        # CoCreateInstance hands back a pointer whose refcount is 1.
        source = create_shelllink_persist(self.IPersist)
        target = self.IPersist()

        # CopyComPointer AddRefs a non-NULL source: refcount 1 -> 2.
        self.assertEqual(S_OK, CopyComPointer(source, byref(target)))
        self.assertEqual(source.value, target.value)

        # Releasing the source's reference leaves 1, proving it was 2.
        self.assertEqual(1, source.Release())

        self._assert_is_shelllink(target)
        self.assertEqual(0, target.Release())

    def test_src_is_null_and_dest_is_nonnull(self):
        source = self.IPersist()
        original = create_shelllink_persist(self.IPersist)
        target = self.IPersist()
        CopyComPointer(original, byref(target))
        self.assertEqual(1, original.Release())

        self._assert_is_shelllink(target)

        # Copying NULL over `target` does NOT affect the refcount of the
        # object `target` used to point to.
        self.assertEqual(S_OK, CopyComPointer(source, byref(target)))
        self.assertIsNone(target.value)

        with self.assertRaises(ValueError):
            target.GetClassID()  # calling through a NULL COM pointer

        # Exactly one reference was still outstanding on `original`.
        self.assertEqual(0, original.Release())

    def test_both_are_nonnull(self):
        source = create_shelllink_persist(self.IPersist)
        original = create_shelllink_persist(self.IPersist)
        target = self.IPersist()
        CopyComPointer(original, byref(target))
        self.assertEqual(1, original.Release())

        self.assertEqual(target.value, original.value)
        self.assertNotEqual(source.value, target.value)

        self.assertEqual(S_OK, CopyComPointer(source, byref(target)))
        self.assertEqual(source.value, target.value)
        self.assertNotEqual(target.value, original.value)

        self.assertEqual(1, source.Release())

        self._assert_is_shelllink(target)

        self.assertEqual(0, target.Release())
        # The overwritten pointer was not released by CopyComPointer.
        self.assertEqual(0, original.Release())
283
284
if __name__ == "__main__":
    # Allow running this module directly as well as via a test runner.
    unittest.main()
287