• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# -*- coding: utf-8 -*-
2from __future__ import division
3import pytest
4import sys
5
6import env  # noqa: F401
7
8from pybind11_tests import pytypes as m
9from pybind11_tests import debug_enabled
10
11
def test_int(doc):
    """The generated signature of get_int advertises an ``int`` return type."""
    expected_signature = "get_int() -> int"
    assert doc(m.get_int) == expected_signature
14
15
def test_iterator(doc):
    """The generated signature of get_iterator advertises an ``Iterator`` return type."""
    expected_signature = "get_iterator() -> Iterator"
    assert doc(m.get_iterator) == expected_signature
18
19
def test_iterable(doc):
    """The generated signature of get_iterable advertises an ``Iterable`` return type."""
    expected_signature = "get_iterable() -> Iterable"
    assert doc(m.get_iterable) == expected_signature
22
23
def test_list(capture, doc):
    """Round-trip a py::list: build it in C++, extend it in Python, print it in C++."""
    expected_output = """
        Entry at position 0: value
        list item 0: inserted-0
        list item 1: overwritten
        list item 2: inserted-2
        list item 3: value2
    """
    with capture:
        items = m.get_list()
        assert items == ["inserted-0", "overwritten", "inserted-2"]

        items.append("value2")
        m.print_list(items)
    assert capture.unordered == expected_output

    assert doc(m.get_list) == "get_list() -> list"
    assert doc(m.print_list) == "print_list(arg0: list) -> None"
44
45
def test_none(capture, doc):
    """None appears as ``None`` both as a return type and as an argument type."""
    signatures = [
        (m.get_none, "get_none() -> None"),
        (m.print_none, "print_none(arg0: None) -> None"),
    ]
    for func, expected in signatures:
        assert doc(func) == expected
49
50
def test_set(capture, doc):
    """Round-trip a py::set and check the set-specific bindings.

    ``get_set`` builds a set in C++, Python adds an element, and ``print_set``
    prints every element from C++. Iteration order of a set is arbitrary, so
    the captured output is compared with ``capture.unordered``.
    ``set_contains`` is exercised with empty, int, and str contents.
    """
    s = m.get_set()
    assert s == {"key1", "key2", "key3"}

    with capture:
        s.add("key4")
        m.print_set(s)
    assert (
        capture.unordered
        == """
        key: key1
        key: key2
        key: key3
        key: key4
    """
    )

    assert not m.set_contains(set(), 42)
    assert m.set_contains({42}, 42)
    assert m.set_contains({"foo"}, "foo")

    # Check the set bindings' own signatures (get_list/print_list are already
    # covered by test_list).
    assert doc(m.get_set) == "get_set() -> set"
    assert doc(m.print_set) == "print_set(arg0: set) -> None"
74
75
def test_dict(capture, doc):
    """Round-trip a py::dict and exercise dict_contains and keyword construction."""
    dct = m.get_dict()
    assert dct == {"key": "value"}

    with capture:
        dct["key2"] = "value2"
        m.print_dict(dct)
    expected_output = """
        key: key, value=value
        key: key2, value=value2
    """
    assert capture.unordered == expected_output

    assert not m.dict_contains({}, 42)
    for key, mapping in ((42, {42: None}), ("foo", {"foo": None})):
        assert m.dict_contains(mapping, key)

    assert doc(m.get_dict) == "get_dict() -> dict"
    assert doc(m.print_dict) == "print_dict(arg0: dict) -> None"

    assert m.dict_keyword_constructor() == {"x": 1, "y": 2, "z": 3}
99
100
def test_str(doc):
    """py::str from std::string, bytes, objects (str/repr), handles, and format()."""
    assert m.str_from_string().encode().decode() == "baz"
    assert m.str_from_bytes().encode().decode() == "boo"

    assert doc(m.str_from_bytes) == "str_from_bytes() -> str"

    class StrAndRepr(object):
        def __str__(self):
            return "this is a str"

        def __repr__(self):
            return "this is a repr"

    obj = StrAndRepr()
    assert m.str_from_object(obj) == "this is a str"
    assert m.repr_from_object(obj) == "this is a repr"
    assert m.str_from_handle(obj) == "this is a str"

    first, second = m.str_format()
    assert first == "1 + 2 = 3"
    assert second == first

    malformed_utf8 = b"\x80"
    # Current behavior, to be fixed (see #2380): the bytes object comes back as-is.
    assert m.str_from_object(malformed_utf8) is malformed_utf8
    if not env.PY2:
        # Intended once #2380 is fixed:
        #     assert m.str_from_object(malformed_utf8) == "b'\\x80'"
        assert m.str_from_handle(malformed_utf8) == "b'\\x80'"
    else:
        # Intended once #2380 is fixed:
        #     with pytest.raises(UnicodeDecodeError):
        #         m.str_from_object(malformed_utf8)
        with pytest.raises(UnicodeDecodeError):
            m.str_from_handle(malformed_utf8)
132
133
def test_bytes(doc):
    """py::bytes round-trips; its signature name depends on the Python major version."""
    assert m.bytes_from_string().decode() == "foo"
    assert m.bytes_from_str().decode() == "bar"

    # Python 2 has no distinct bytes type (bytes is str there).
    return_type = "str" if env.PY2 else "bytes"
    assert doc(m.bytes_from_str) == "bytes_from_str() -> {}".format(return_type)
141
142
def test_capsule(capture):
    """Capsule destructors run (and print) once the capsule is released and collected."""

    def check_lifecycle(make_capsule, expected_output):
        # Create the capsule and drop its only reference inside the capture, so
        # both the creation and the destruction messages are recorded.
        with capture:
            capsule = make_capsule()
            del capsule
            pytest.gc_collect()
        assert capture.unordered == expected_output

    pytest.gc_collect()
    check_lifecycle(
        m.return_capsule_with_destructor,
        """
        creating capsule
        destructing capsule
    """,
    )
    check_lifecycle(
        m.return_capsule_with_destructor_2,
        """
        creating capsule
        destructing capsule: 1234
    """,
    )
    check_lifecycle(
        m.return_capsule_with_name_and_destructor,
        """
        created capsule (1234, 'pointer type description')
        destructing capsule (1234, 'pointer type description')
    """,
    )
180
181
def test_accessors():
    """Exercise the C++ accessor API (attr/item/call/deref) against a Python object."""

    class SubTestObject:
        attr_obj = 1
        attr_char = 2

    class TestObject:
        basic_attr = 1
        begin_end = [1, 2, 3]
        d = {"operator[object]": 1, "operator[char *]": 2}
        sub = SubTestObject()

        def func(self, x, *args):
            return self.basic_attr + x + sum(args)

    results = m.accessor_api(TestObject())
    expected_results = [
        ("basic_attr", 1),
        ("begin_end", [1, 2, 3]),
        ("operator[object]", 1),
        ("operator[char *]", 2),
        ("attr(object)", 1),
        ("attr(char *)", 2),
        ("missing_attr_ptr", "raised"),
        ("missing_attr_chain", "raised"),
        ("operator()", 2),
        ("operator*", 7),
        ("implicit_list", [1, 2, 3]),
    ]
    for key, value in expected_results:
        assert results[key] == value
    assert results["is_none"] is False
    assert all(name in TestObject.__dict__ for name in results["implicit_dict"])

    assert m.tuple_accessor(()) == (0, 1, 2)

    assigned = m.accessor_assignment()
    expected_assigned = [
        ("get", 0),
        ("deferred_get", 0),
        ("set", 1),
        ("deferred_set", 1),
        ("var", 99),
    ]
    for key, value in expected_assigned:
        assert assigned[key] == value
219
220
def test_constructors():
    """C++ default and converting constructors match calling the Python type."""
    default_types = [bytes, str, bool, int, float, tuple, list, dict, set]
    expected_defaults = {}
    for tp in default_types:
        expected_defaults[tp.__name__] = tp()
    if env.PY2:
        # bytes is str on Python 2 (so bytes.__name__ == 'str'), whereas
        # pybind11::str is unicode even there; spell both entries out.
        expected_defaults["bytes"] = bytes()
        expected_defaults["str"] = unicode()  # noqa: F821
    assert m.default_constructors() == expected_defaults

    # Python type -> input handed to the matching C++ converting constructor.
    data = {
        bytes: b"41",  # Currently no supported or working conversions.
        str: 42,
        bool: "Not empty",
        int: "42",
        float: "+1e3",
        tuple: range(3),
        list: range(3),
        dict: [("two", 2), ("one", 1), ("three", 3)],
        set: [4, 4, 5, 6, 6, 6],
        memoryview: b"abc",
    }
    inputs = {}
    expected = {}
    for tp, value in data.items():
        inputs[tp.__name__] = value
        expected[tp.__name__] = tp(value)
    if env.PY2:  # Same bytes/str aliasing as for the defaults above.
        inputs["bytes"] = b"41"
        inputs["str"] = 42
        expected["bytes"] = b"41"
        expected["str"] = u"42"

    assert m.converting_constructors(inputs) == expected
    assert m.cast_functions(inputs) == expected

    # When no conversion is needed, both APIs should hand back the very same
    # objects (a reference, not a copy).
    for convert in (m.converting_constructors, m.cast_functions):
        unconverted = convert(expected)
        for name in unconverted:
            assert unconverted[name] is expected[name]
264
265
def test_non_converting_constructors():
    """Non-converting constructors reject a value of the wrong type with TypeError."""
    cases = [
        ("bytes", range(10)),
        ("none", 42),
        ("ellipsis", 42),
        ("type", 42),
    ]
    for type_name, value in cases:
        message = "Object of type '{}' is not an instance of '{}'".format(
            type(value).__name__, type_name
        )
        for move in (True, False):
            with pytest.raises(TypeError) as excinfo:
                m.nonconverting_constructor(type_name, value, move)
            assert str(excinfo.value) == message
281
282
def test_pybind11_str_raw_str():
    """Exercise pybind11::str::raw_str via ``convert_to_pybind11_str``.

    Each input is turned into a ``py::str`` on the C++ side. ``bytes`` input is
    the exception: it is handed back unchanged rather than stringified (the
    UTF-8 cases at the end pin the returned type).
    """
    cvt = m.convert_to_pybind11_str
    assert cvt(u"Str") == u"Str"
    # bytes comes back as the same bytes value on both Python 2 and 3 (the
    # pass-through pinned below), so it never equals the Python 3 repr "b'Bytes'".
    assert cvt(b"Bytes") == b"Bytes"
    assert cvt(None) == u"None"
    assert cvt(False) == u"False"
    assert cvt(True) == u"True"
    assert cvt(42) == u"42"
    assert cvt(2 ** 65) == u"36893488147419103232"
    assert cvt(-1.50) == u"-1.5"
    assert cvt(()) == u"()"
    assert cvt((18,)) == u"(18,)"
    assert cvt([]) == u"[]"
    assert cvt([28]) == u"[28]"
    assert cvt({}) == u"{}"
    assert cvt({3: 4}) == u"{3: 4}"
    # The conditional must be parenthesized: `assert x == a if c else b` parses
    # as `assert (x == a) if c else b`, which on Python 3 asserts a non-empty
    # string and therefore can never fail.
    assert cvt(set()) == (u"set([])" if env.PY2 else "set()")
    assert cvt({3, 3}) == (u"set([3])" if env.PY2 else "{3}")

    # U+01F1 LATIN CAPITAL LETTER DZ (UTF-8: C7 B1), written as an escape so the
    # test does not depend on the source file's encoding surviving intact.
    valid_orig = u"\u01f1"
    valid_utf8 = valid_orig.encode("utf-8")
    valid_cvt = cvt(valid_utf8)
    assert type(valid_cvt) is bytes  # Probably surprising.
    assert valid_cvt == b"\xc7\xb1"

    malformed_utf8 = b"\x80"
    malformed_cvt = cvt(malformed_utf8)
    assert type(malformed_cvt) is bytes  # Probably surprising.
    assert malformed_cvt == b"\x80"
313
314
def test_implicit_casting():
    """Implicit casts on dict item assignment and list append produce Python values."""
    casted = m.get_implicit_casting()
    expected_items = {
        "char*_i1": "abc",
        "char*_i2": "abc",
        "char*_e": "abc",
        "char*_p": "abc",
        "str_i1": "str",
        "str_i2": "str1",
        "str_e": "str2",
        "str_p": "str3",
        "int_i1": 42,
        "int_i2": 42,
        "int_e": 43,
        "int_p": 44,
    }
    expected_list = [3, 6, 9, 12, 15]
    assert casted["d"] == expected_items
    assert casted["l"] == expected_list
333
334
def test_print(capture):
    """py::print output on stdout/stderr, and its error for an unconvertible argument."""
    expected_stdout = """
        Hello, World!
        1 2.0 three True -- multiple args
        *args-and-a-custom-separator
        no new line here -- next print
        flush
        py::print + str.format = this
    """
    with capture:
        m.print_function()
    assert capture == expected_stdout
    assert capture.stderr == "this goes to stderr"

    # Debug builds name the offending type; release builds give a generic hint.
    if debug_enabled:
        detail = "argument of type 'UnregisteredType' to Python object"
    else:
        detail = "arguments to Python object (compile in debug mode for details)"
    with pytest.raises(RuntimeError) as excinfo:
        m.print_failure()
    assert str(excinfo.value) == "make_tuple(): unable to convert " + detail
358
359
def test_hash():
    """hash_function returns __hash__'s value and raises TypeError for unhashables."""

    class WithHash(object):
        def __init__(self, value):
            self._hash_value = value

        def __hash__(self):
            return self._hash_value

    class WithoutHash(object):
        __hash__ = None

    assert m.hash_function(WithHash(42)) == 42
    with pytest.raises(TypeError):
        m.hash_function(WithoutHash())
374
375
def test_number_protocol():
    """The C++ number-protocol wrappers agree with Python's own operators."""
    for x, y in ((1, 1), (3, 5)):
        comparisons = [x == y, x != y, x < y, x <= y, x > y, x >= y]
        arithmetic = [x + y, x - y, x * y, x / y]
        bitwise = [x | y, x & y, x ^ y, x >> y, x << y]
        assert m.test_number_protocol(x, y) == comparisons + arithmetic + bitwise
396
397
def test_list_slicing():
    """The C++ list slice agrees with Python's ``[::2]`` extended slice."""
    values = list(range(100))
    every_other = values[::2]
    assert every_other == m.test_list_slicing(values)
401
402
def test_issue2361():
    """Regression test for issue #2361 (implicit copies from None)."""
    # The str variant yields the text "None".
    assert m.issue2361_str_implicit_copy_none() == "None"
    # The dict variant fails because None is not iterable.
    expected_fragment = "'NoneType' object is not iterable"
    with pytest.raises(TypeError) as excinfo:
        assert m.issue2361_dict_implicit_copy_none()
    assert expected_fragment in str(excinfo.value)
409
410
@pytest.mark.parametrize(
    "method, args, fmt, expected_view",
    [
        (m.test_memoryview_object, (b"red",), "B", b"red"),
        (m.test_memoryview_buffer_info, (b"green",), "B", b"green"),
        (m.test_memoryview_from_buffer, (False,), "h", [3, 1, 4, 1, 5]),
        (m.test_memoryview_from_buffer, (True,), "H", [2, 7, 1, 8]),
        (m.test_memoryview_from_buffer_nativeformat, (), "@i", [4, 7, 5]),
    ],
)
def test_memoryview(method, args, fmt, expected_view):
    """Each memoryview factory yields a view with the expected format and items."""
    view = method(*args)
    assert isinstance(view, memoryview)
    assert view.format == fmt
    if env.PY2 and not isinstance(expected_view, bytes):
        # On Python 2 each multi-byte item iterates as single-byte characters;
        # max picks the non-zero byte whatever the endianness.
        view_as_list = [max(ord(c) for c in item) for item in view]
    else:
        view_as_list = list(view)
    assert view_as_list == list(expected_view)
431
432
@pytest.mark.xfail("env.PYPY", reason="getrefcount is not available")
@pytest.mark.parametrize(
    "method",
    [
        m.test_memoryview_object,
        m.test_memoryview_buffer_info,
    ],
)
def test_memoryview_refcount(method):
    """Creating a view from a bytes object raises that object's reference count."""
    source = b"\x0a\x0b\x0c\x0d"
    refs_before = sys.getrefcount(source)
    view = method(source)  # keep the view alive while counting
    refs_after = sys.getrefcount(source)
    assert refs_after > refs_before
    assert list(view) == list(source)
448
449
def test_memoryview_from_buffer_empty_shape():
    """An empty-shape memoryview has format "B" and, on Python 3, no bytes."""
    view = m.test_memoryview_from_buffer_empty_shape()
    assert isinstance(view, memoryview)
    assert view.format == "B"
    if not env.PY2:
        assert bytes(view) == b""
    else:
        # On Python 2 bytes() is str(), which gives a textual description of the
        # view rather than its contents. PyPy prints "<memoryview ..." and
        # CPython 2 prints "<memory ...", so only the shared prefix is checked.
        assert bytes(view).startswith(b"<memory")
460
461
def test_test_memoryview_from_buffer_invalid_strides():
    """The invalid-strides factory raises RuntimeError."""
    pytest.raises(RuntimeError, m.test_memoryview_from_buffer_invalid_strides)
465
466
def test_test_memoryview_from_buffer_nullptr():
    """The nullptr factory raises ValueError on Python 3 and succeeds on Python 2."""
    if not env.PY2:
        with pytest.raises(ValueError):
            m.test_memoryview_from_buffer_nullptr()
        return
    m.test_memoryview_from_buffer_nullptr()
473
474
@pytest.mark.skipif("env.PY2")
def test_memoryview_from_memory():
    """A memoryview over raw memory exposes four unsigned bytes (format "B")."""
    result = m.test_memoryview_from_memory()
    assert isinstance(result, memoryview)
    assert (result.format, bytes(result)) == ("B", b"\xff\xe1\xab\x37")
481
482
def test_builtin_functions():
    """get_len wraps len(): it works on sized objects and raises TypeError otherwise."""
    assert m.get_len(list(range(42))) == 42
    # A generator has no length, so the call must fail with TypeError.
    with pytest.raises(TypeError) as exc_info:
        m.get_len(i for i in range(42))
    # The message wording differs between interpreters.
    assert str(exc_info.value) in [
        "object of type 'generator' has no len()",  # CPython
        "'generator' has no length",  # PyPy
    ]
491