• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# -*- coding: utf-8 -*-
2import pytest
3
4import env  # noqa: F401
5
6m = pytest.importorskip("pybind11_tests.virtual_functions")
7from pybind11_tests import ConstructorStats  # noqa: E402
8
9
def test_override(capture, msg):
    """Python subclasses can override ExampleVirt's virtual methods.

    Exercises the plain C++ object, a direct Python subclass and a
    second-level Python subclass, then verifies through ConstructorStats
    that exactly three ExampleVirt objects were built and all released.
    """

    class ExtendedExampleVirt(m.ExampleVirt):
        def __init__(self, state):
            # The C++ base receives state + 1.
            super(ExtendedExampleVirt, self).__init__(state + 1)
            self.data = "Hello world"

        def run(self, value):
            print("ExtendedExampleVirt::run(%i), calling parent.." % value)
            # Chain to the C++ implementation with value + 1.
            return super(ExtendedExampleVirt, self).run(value + 1)

        def run_bool(self):
            print("ExtendedExampleVirt::run_bool()")
            return False

        def get_string1(self):
            return "override1"

        def pure_virtual(self):
            print("ExtendedExampleVirt::pure_virtual(): %s" % self.data)

    class ExtendedExampleVirt2(ExtendedExampleVirt):
        def __init__(self, state):
            # One more +1 on top of ExtendedExampleVirt's own increment.
            super(ExtendedExampleVirt2, self).__init__(state + 1)

        def get_string2(self):
            return "override2"

    # Plain C++ instance: nothing is overridden.
    native = m.ExampleVirt(10)
    with capture:
        assert m.runExampleVirt(native, 20) == 30
    assert (
        capture
        == """
        Original implementation of ExampleVirt::run(state=10, value=20, str1=default1, str2=default2)
    """  # noqa: E501 line too long
    )

    # The pure virtual has no implementation on the bare C++ object.
    with pytest.raises(RuntimeError) as excinfo:
        m.runExampleVirtVirtual(native)
    pure_msg = 'Tried to call pure virtual function "ExampleVirt::pure_virtual"'
    assert msg(excinfo.value) == pure_msg

    # One Python level: run, run_bool, get_string1 and pure_virtual overridden.
    child = ExtendedExampleVirt(10)
    with capture:
        assert m.runExampleVirt(child, 20) == 32
    assert (
        capture
        == """
        ExtendedExampleVirt::run(20), calling parent..
        Original implementation of ExampleVirt::run(state=11, value=21, str1=override1, str2=default2)
    """  # noqa: E501 line too long
    )
    with capture:
        assert m.runExampleVirtBool(child) is False
    assert capture == "ExtendedExampleVirt::run_bool()"
    with capture:
        m.runExampleVirtVirtual(child)
    assert capture == "ExtendedExampleVirt::pure_virtual(): Hello world"

    # Two Python levels: get_string2 is overridden as well.
    grandchild = ExtendedExampleVirt2(15)
    with capture:
        assert m.runExampleVirt(grandchild, 50) == 68
    assert (
        capture
        == """
        ExtendedExampleVirt::run(50), calling parent..
        Original implementation of ExampleVirt::run(state=17, value=51, str1=override1, str2=override2)
    """  # noqa: E501 line too long
    )

    stats = ConstructorStats.get(m.ExampleVirt)
    assert stats.alive() == 3
    # Drop every reference so all three objects are destroyed.
    del native, child, grandchild
    assert stats.alive() == 0
    # C++-side states in construction order: 10, 10 + 1, 15 + 1 + 1.
    assert stats.values() == ["10", "11", "17"]
    assert stats.copy_constructions == 0
    # NOTE(review): the move count is not pinned (presumably build-dependent);
    # this comparison only checks that the counter can be read.
    assert stats.move_constructions >= 0
89
90
def test_alias_delay_initialization1(capture):
    """Only a Python subclass of ``m.A`` goes through its trampoline.

    Creating and using a bare ``m.A()`` skips the trampoline initialization
    entirely (for performance reasons) and only constructs the C++ ``A``.
    A Python subclass, in contrast, is built on top of the ``PyA`` trampoline.
    """

    class B(m.A):
        def __init__(self):
            super(B, self).__init__()

        def f(self):
            print("In python f()")

    # Direct C++ instance: no trampoline output, f() runs the C++ version.
    with capture:
        plain = m.A()
        m.call_f(plain)
        del plain
        pytest.gc_collect()
    assert capture == "A.f()"

    # Python subclass: the trampoline is constructed, forwards f() to Python,
    # and is destroyed once the object is collected.
    expected = """
        PyA.PyA()
        PyA.f()
        In python f()
        PyA.~PyA()
    """
    with capture:
        derived = B()
        m.call_f(derived)
        del derived
        pytest.gc_collect()
    assert capture == expected
128
129
def test_alias_delay_initialization2(capture):
    """``m.A2`` always initializes its alias (trampoline), unlike ``m.A``.

    The extra class layer costs a little on virtual dispatch. In exchange,
    the trampoline can keep local variables and run construction/destruction
    code. That shows up here as ``PyA2`` constructor/destructor output even
    when no Python subclass is involved.
    """

    class B2(m.A2):
        def __init__(self):
            super(B2, self).__init__()

        def f(self):
            print("In python B2.f()")

    # No Python subclass: both the no-argument and the one-argument
    # construction still go through PyA2, and f() falls back to A2.f().
    with capture:
        for ctor_args in ((), (1,)):
            instance = m.A2(*ctor_args)
            m.call_f(instance)
            del instance
            pytest.gc_collect()
    assert (
        capture
        == """
        PyA2.PyA2()
        PyA2.f()
        A2.f()
        PyA2.~PyA2()
        PyA2.PyA2()
        PyA2.f()
        A2.f()
        PyA2.~PyA2()
    """
    )

    # Python subclass: PyA2 forwards f() to the Python override.
    with capture:
        subclassed = B2()
        m.call_f(subclassed)
        del subclassed
        pytest.gc_collect()
    assert (
        capture
        == """
        PyA2.PyA2()
        PyA2.f()
        In python B2.f()
        PyA2.~PyA2()
    """
    )
184
185
# PyPy: a reference count > 1 on the noncopyable instance makes the call in
# ncv_fresh.print_nc() fail there, hence the xfail.
@pytest.mark.xfail("env.PYPY")
@pytest.mark.skipif(
    not hasattr(m, "NCVirt"), reason="NCVirt does not work on Intel/PGI/NVCC compilers"
)
def test_move_support():
    """Objects returned from Python overrides reach C++ without illegal copies.

    A NonCopyable that the override still references cannot be handed over,
    so print_nc raises. A still-referenced Movable is presumably copied
    instead, which is what the single Movable copy construction reflects.
    """

    class NCVirtExt(m.NCVirt):
        def get_noncopyable(self, a, b):
            # Fresh instance that nothing else references.
            return m.NonCopyable(a * a, b * b)

        def get_movable(self, a, b):
            # Keep our own reference to the returned object.
            self.movable = m.Movable(a, b)
            return self.movable

    class NCVirtExt2(m.NCVirt):
        def get_noncopyable(self, a, b):
            # Holding on to the instance makes print_nc raise (checked below).
            self.nc = m.NonCopyable(a, b)
            return self.nc

        def get_movable(self, a, b):
            # Fresh, unreferenced instance.
            return m.Movable(a, b)

    ncv_fresh = NCVirtExt()
    assert ncv_fresh.print_nc(2, 3) == "36"
    assert ncv_fresh.print_movable(4, 5) == "9"

    ncv_keep = NCVirtExt2()
    assert ncv_keep.print_movable(7, 7) == "14"
    # Only the exception type is checked: the message differs between
    # debug and non-debug builds.
    with pytest.raises(RuntimeError):
        ncv_keep.print_nc(9, 9)

    nc_stats = ConstructorStats.get(m.NonCopyable)
    mv_stats = ConstructorStats.get(m.Movable)
    # Survivors are the objects stored on the instances (self.movable, self.nc).
    assert (nc_stats.alive(), mv_stats.alive()) == (1, 1)
    del ncv_fresh, ncv_keep
    assert (nc_stats.alive(), mv_stats.alive()) == (0, 0)
    assert nc_stats.values() == ["4", "9", "9", "9"]
    assert mv_stats.values() == ["4", "5", "7", "7"]
    assert (nc_stats.copy_constructions, mv_stats.copy_constructions) == (0, 1)
    # NOTE(review): move counts are not pinned; these only check readability.
    assert nc_stats.move_constructions >= 0
    assert mv_stats.move_constructions >= 0
236
237
def test_dispatch_issue(msg):
    """Regression test for #159: virtual dispatch vs. similarly named functions."""
    pure_virtual_error = 'Tried to call pure virtual function "Base::dispatch"'

    class PyClass1(m.DispatchIssue):
        def dispatch(self):
            return "Yay.."

    class PyClass2(m.DispatchIssue):
        def dispatch(self):
            # The C++ base implementation is pure virtual and must refuse to run.
            with pytest.raises(RuntimeError) as excinfo:
                super(PyClass2, self).dispatch()
            assert msg(excinfo.value) == pure_virtual_error
            # Re-enter C++ dispatch with a different Python subclass.
            return m.dispatch_issue_go(PyClass1())

    outer = PyClass2()
    assert m.dispatch_issue_go(outer) == "Yay.."
258
259
def test_override_ref():
    """Regression test for #392/#397: overriding reference-returning functions."""
    override = m.OverrideTest("asdf")

    # str_ref() is deliberately not called: doing so is not allowed (see the
    # associated .cpp comment), so only the by-value accessor is checked.
    assert override.str_value() == "asdf"

    assert override.A_value().value == "hi"

    ref = override.A_ref()
    assert ref.value == "hi"
    ref.value = "bye"
    assert ref.value == "bye"
274
275
def test_inherited_virtuals():
    """Virtual overrides resolve correctly across repeated and templated hierarchies.

    ``A``/``B``/``C``/``D`` each come in a ``*_Repeat`` and a ``*_Tpl`` flavour
    exposing the same methods. Every flavour is checked with its C++ defaults
    and with one or two layers of Python overrides. Judging by the expected
    strings, ``say_everything()`` behaves like
    ``say_something(1) + " " + str(unlucky_number())`` dispatched virtually,
    so it reflects whichever overrides are active.
    """

    def check(obj, says, unlucky, everything, lucky=None):
        """Assert the virtual-method results of *obj* (lucky_number if given)."""
        assert obj.say_something(3) == says
        assert obj.unlucky_number() == unlucky
        # lucky_number is not checked on the A_* objects in this test.
        if lucky is not None:
            assert obj.lucky_number() == lucky
        assert obj.say_everything() == everything

    class AR(m.A_Repeat):
        def unlucky_number(self):
            return 99

    class AT(m.A_Tpl):
        def unlucky_number(self):
            return 999

    check(AR(), "hihihi", 99, "hi 99")
    check(AT(), "hihihi", 999, "hi 999")

    # C++ defaults. C keeps B's say_something but returns different numbers.
    for obj in [m.B_Repeat(), m.B_Tpl()]:
        check(obj, "B says hi 3 times", 13, "B says hi 1 times 13", lucky=7.0)
    for obj in [m.C_Repeat(), m.C_Tpl()]:
        check(obj, "B says hi 3 times", 4444, "B says hi 1 times 4444", lucky=888.0)

    class CR(m.C_Repeat):
        def lucky_number(self):
            return m.C_Repeat.lucky_number(self) + 1.25

    check(CR(), "B says hi 3 times", 4444, "B says hi 1 times 4444", lucky=889.25)

    class CT(m.C_Tpl):
        pass

    check(CT(), "B says hi 3 times", 4444, "B says hi 1 times 4444", lucky=888.0)

    # Two Python levels, each chaining to its parent's lucky_number().
    class CCR(CR):
        def lucky_number(self):
            return CR.lucky_number(self) * 10

    check(CCR(), "B says hi 3 times", 4444, "B says hi 1 times 4444", lucky=8892.5)

    class CCT(CT):
        def lucky_number(self):
            return CT.lucky_number(self) * 1000

    check(CCT(), "B says hi 3 times", 4444, "B says hi 1 times 4444", lucky=888000.0)

    class DR(m.D_Repeat):
        def unlucky_number(self):
            return 123

        def lucky_number(self):
            return 42.0

    for obj in [m.D_Repeat(), m.D_Tpl()]:
        check(obj, "B says hi 3 times", 4444, "B says hi 1 times 4444", lucky=888.0)

    check(DR(), "B says hi 3 times", 123, "B says hi 1 times 123", lucky=42.0)

    class DT(m.D_Tpl):
        def say_something(self, times):
            return "DT says:" + (" quack" * times)

        def unlucky_number(self):
            return 1234

        def lucky_number(self):
            return -4.25

    check(DT(), "DT says: quack quack quack", 1234, "DT says: quack 1234", lucky=-4.25)

    class DT2(DT):
        def say_something(self, times):
            return "DT2: " + ("QUACK" * times)

        def unlucky_number(self):
            return -3

    # Two Python levels above D_Tpl: DT2's overrides take precedence, and
    # lucky_number() is inherited from DT.
    check(DT2(), "DT2: QUACKQUACKQUACK", -3, "DT2: QUACK -3", lucky=-4.25)

    class BT(m.B_Tpl):
        def say_something(self, times):
            return "BT" * times

        def unlucky_number(self):
            return -7

        def lucky_number(self):
            return -1.375

    check(BT(), "BTBTBT", -7, "BT -7", lucky=-1.375)
403
404
def test_issue_1454():
    """Regression test for #1454.

    Acquiring/releasing the GIL on another thread used to crash (Python 2.7).
    The test passes as long as both helpers return without raising.
    """
    # Same order as before: the plain helper first, then the from-thread one.
    for helper_name in ("test_gil", "test_gil_from_thread"):
        getattr(m, helper_name)()
409