• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# -*- coding: utf-8 -*-
2import pytest
3
4import env  # noqa: F401
5
6from pybind11_tests import stl_binders as m
7
8
def test_vector_int():
    """Exercise the list-like API of the bound ``VectorInt`` type.

    Covers construction (from a list and from a generator), equality, item
    assignment, ``append``/``insert`` (including an out-of-range position and
    a negative one), slice assignment and deletion, ``extend`` (from a bound
    vector, a list and a generator), the vector being left unchanged when an
    ``extend`` fails, negative indexing and ``clear``.
    """
    pair = m.VectorInt([0, 0])
    assert len(pair) == 2
    assert bool(pair) is True

    # Construction also accepts a generator.
    from_gen = m.VectorInt(n for n in range(5))
    assert from_gen == m.VectorInt([0, 1, 2, 3, 4])

    work = m.VectorInt([0, 0])
    assert pair == work
    work[1] = 1
    assert pair != work

    work.append(2)
    for value in (1, 2, 3):
        work.insert(0, value)
    work.insert(6, 3)  # position 6 == current length, so this lands at the end
    assert str(work) == "VectorInt[3, 2, 1, 0, 1, 2, 3]"
    with pytest.raises(IndexError):
        work.insert(8, 4)  # beyond the current length (7)

    pair.append(99)
    work[2:-2] = pair
    assert work == m.VectorInt([3, 2, 0, 0, 99, 2, 3])
    del work[1:3]
    assert work == m.VectorInt([3, 0, 99, 2, 3])
    del work[0]
    assert work == m.VectorInt([0, 99, 2, 3])

    work.extend(m.VectorInt([4, 5]))
    assert work == m.VectorInt([0, 99, 2, 3, 4, 5])

    work.extend([6, 7])
    prefix = [0, 99, 2, 3, 4, 5, 6, 7]
    assert work == m.VectorInt(prefix)

    # A failing extend raises and must leave the vector unchanged.
    with pytest.raises(RuntimeError):
        work.extend([8, "a"])
    assert work == m.VectorInt(prefix)

    # Extending from a generator appends every yielded item.
    work.extend(n for n in range(5))
    assert work == m.VectorInt(prefix + [0, 1, 2, 3, 4])

    # Negative indices count from the end.
    assert work[-1] == 4

    work.insert(-1, 88)
    assert work == m.VectorInt(prefix + [0, 1, 2, 3, 88, 4])

    del work[-1]
    assert work == m.VectorInt(prefix + [0, 1, 2, 3, 88])

    work.clear()
    assert len(work) == 0
69
70
# Older PyPy versions failed here; the failure was related to PyPy's buffer protocol.
def test_vector_buffer():
    """Check the buffer protocol exposed by ``VectorUChar``.

    The vector is built from a ``bytearray``, modified through a
    ``memoryview`` of itself and, on Python 3, built from a strided
    ``memoryview``. Finally ``create_undeclstruct`` must raise because the
    struct it uses has no registered buffer type info.
    """
    raw = bytearray([1, 2, 3, 4])
    vec = m.VectorUChar(raw)
    assert vec[1] == 2
    vec[2] = 5

    view = memoryview(vec)  # VectorUChar exposes the buffer interface
    # Python 2 memoryview items are 1-byte strings; Python 3 items are ints.
    if env.PY2:
        assert view[2] == "\x05"
        view[2] = "\x06"
    else:
        assert view[2] == 5
        view[2] = 6
    # The write through the view must be visible in the vector itself.
    assert vec[2] == 6

    if not env.PY2:
        strided = memoryview(raw)[::2]
        vec = m.VectorUChar(strided)
        assert vec[1] == 3

    with pytest.raises(RuntimeError) as excinfo:
        m.create_undeclstruct()  # undeclared struct contents -> no buffer interface
    assert "NumPy type info missing for " in str(excinfo.value)
94
95
def test_vector_buffer_numpy():
    """Construct bound vectors from NumPy arrays and view them as NumPy arrays.

    Skipped when NumPy cannot be imported.
    """
    np = pytest.importorskip("numpy")

    # A signed int32 array is rejected by VectorInt's constructor.
    with pytest.raises(TypeError):
        m.VectorInt(np.array([1, 2, 3, 4], dtype=np.int32))

    grid = np.array(
        [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=np.uintc
    )

    # Row slice (contiguous) of an np.uintc array is accepted.
    row = m.VectorInt(grid[0, :])
    assert len(row) == 4
    assert row[2] == 3
    # Writes into np.asarray(row) must show up in the vector.
    row_view = np.asarray(row)
    row_view[2] = 5
    assert row[2] == 5

    # Column slice (strided) of the same array is accepted too.
    col = m.VectorInt(grid[:, 1])
    assert len(col) == 3
    assert col[2] == 10

    structs = m.get_vectorstruct()
    assert structs[0].x == 5
    struct_view = np.asarray(structs)
    struct_view[1]["x"] = 99
    assert structs[1].x == 99

    aligned = np.dtype(
        [("w", "bool"), ("x", "I"), ("y", "float64"), ("z", "bool")], align=True
    )
    structs = m.VectorStruct(np.zeros(3, dtype=aligned))
    assert len(structs) == 3

    # A strided uint8 array is accepted by VectorUChar.
    small = np.array([1, 2, 3, 4], dtype=np.uint8)
    every_other = m.VectorUChar(small[::2])
    assert every_other[1] == 3
133
134
def test_vector_bool():
    """Round-trip booleans through ``VectorBool`` from the cross-module test extension."""
    import pybind11_cross_module_tests as cm

    pattern = [n % 2 == 0 for n in range(10)]
    bools = cm.VectorBool()
    for flag in pattern:
        bools.append(flag)
    for idx, flag in enumerate(pattern):
        assert bools[idx] == flag
    assert str(bools) == "VectorBool[1, 0, 1, 0, 1, 0, 1, 0, 1, 0]"
144
145
def test_vector_custom():
    """Check ``str()`` of a vector of custom ``El`` elements and of a nested vector."""
    elems = m.VectorEl()
    for value in (1, 2):
        elems.append(m.El(value))
    expected = "VectorEl[El{1}, El{2}]"
    assert str(elems) == expected

    nested = m.VectorVectorEl()
    nested.append(elems)
    # Reading the inner vector back must show the same contents.
    assert str(nested[0]) == expected
156
157
def test_map_string_double():
    """Check assignment, iteration, ``items()`` and ``str()`` of string->double maps."""
    ordered = m.MapStringDouble()
    ordered["a"] = 1
    ordered["b"] = 2.5

    # The ordered map is expected to iterate in key order.
    assert list(ordered) == ["a", "b"]
    assert list(ordered.items()) == [("a", 1), ("b", 2.5)]
    assert str(ordered) == "MapStringDouble{a: 1, b: 2.5}"

    unordered = m.UnorderedMapStringDouble()
    unordered["ua"] = 1.1
    unordered["ub"] = 2.6

    # No iteration order is assumed for the unordered map: sort before comparing.
    assert sorted(unordered) == ["ua", "ub"]
    assert sorted(unordered.items()) == [("ua", 1.1), ("ub", 2.6)]
    assert "UnorderedMapStringDouble" in str(unordered)
174
175
def test_map_string_double_const():
    """Check assignment and ``str()`` of the ``*MapStringDoubleConst`` bindings."""
    mc = m.MapStringDoubleConst()
    mc["a"] = 10
    mc["b"] = 20.5
    assert str(mc) == "MapStringDoubleConst{a: 10, b: 20.5}"

    umc = m.UnorderedMapStringDoubleConst()
    umc["a"] = 11
    umc["b"] = 21.5

    # Same "Name{key: value, ...}" format as the ordered map above, but the
    # entry order of the unordered map is unspecified, so accept either order.
    assert str(umc) in (
        "UnorderedMapStringDoubleConst{a: 11, b: 21.5}",
        "UnorderedMapStringDoubleConst{b: 21.5, a: 11}",
    )
187
188
def _values(container):
    """Return the ``.value`` of every element yielded by iterating ``container``."""
    return [element.value for element in container]


def _values_by_key(mapping):
    """Return ``{key: element.value}`` built from ``mapping.items()``."""
    return {key: element.value for key, element in mapping.items()}


def test_noncopyable_containers():
    """Check indexing and iteration over containers of non-copyable elements.

    Covers the std::vector, std::deque, std::map and std::unordered_map
    bindings returned by the ``get_*nc`` helpers, plus nested map-of-vector,
    map-of-map and unordered_map-of-unordered_map variants. Elements are only
    read through ``.value``; iteration results are compared exactly, so a
    missing or extra element is caught as well as a wrong value.
    """
    one_to_five = [1, 2, 3, 4, 5]
    tens = {k: 10 * k for k in range(1, 6)}  # {1: 10, ..., 5: 50}
    inner_tens = {j: 10 * j for j in range(10, 60, 10)}  # {10: 100, ..., 50: 500}

    # std::vector
    vnc = m.get_vnc(5)
    for i in range(5):
        assert vnc[i].value == i + 1
    assert _values(vnc) == one_to_five

    # std::deque
    dnc = m.get_dnc(5)
    for i in range(5):
        assert dnc[i].value == i + 1
    assert _values(dnc) == one_to_five

    # std::map
    mnc = m.get_mnc(5)
    for k in range(1, 6):
        assert mnc[k].value == 10 * k
    assert _values_by_key(mnc) == tens

    # std::unordered_map (compared as a dict, so iteration order is irrelevant)
    umnc = m.get_umnc(5)
    for k in range(1, 6):
        assert umnc[k].value == 10 * k
    assert _values_by_key(umnc) == tens

    # nested std::map<std::vector>
    nvnc = m.get_nvnc(5)
    for i in range(1, 6):
        for j in range(5):
            assert nvnc[i][j].value == j + 1
    # Note: maps do not have .values(), so go through items().
    assert {i: _values(inner) for i, inner in nvnc.items()} == {
        i: one_to_five for i in range(1, 6)
    }

    # nested std::map<std::map>
    nmnc = m.get_nmnc(5)
    for i in range(1, 6):
        for j in range(10, 60, 10):
            assert nmnc[i][j].value == 10 * j
    assert {i: _values_by_key(inner) for i, inner in nmnc.items()} == {
        i: inner_tens for i in range(1, 6)
    }

    # nested std::unordered_map<std::unordered_map>
    numnc = m.get_numnc(5)
    for i in range(1, 6):
        for j in range(10, 60, 10):
            assert numnc[i][j].value == 10 * j
    assert {i: _values_by_key(inner) for i, inner in numnc.items()} == {
        i: inner_tens for i in range(1, 6)
    }
270
271
def test_map_delitem():
    """Check ``del map[key]`` on the ordered and unordered string->double maps.

    Deleting a present key removes exactly that entry. Deleting a key that is
    no longer present raises ``KeyError`` and leaves the remaining entries
    untouched.
    """
    mm = m.MapStringDouble()
    mm["a"] = 1
    mm["b"] = 2.5

    assert list(mm) == ["a", "b"]
    assert list(mm.items()) == [("a", 1), ("b", 2.5)]
    del mm["a"]
    assert list(mm) == ["b"]
    assert list(mm.items()) == [("b", 2.5)]
    with pytest.raises(KeyError):
        del mm["a"]
    assert list(mm.items()) == [("b", 2.5)]

    um = m.UnorderedMapStringDouble()
    um["ua"] = 1.1
    um["ub"] = 2.6

    # No iteration order is assumed for the unordered map: sort before comparing.
    assert sorted(um) == ["ua", "ub"]
    assert sorted(um.items()) == [("ua", 1.1), ("ub", 2.6)]
    del um["ua"]
    assert sorted(um) == ["ub"]
    assert sorted(um.items()) == [("ub", 2.6)]
    with pytest.raises(KeyError):
        del um["ua"]
    assert sorted(um.items()) == [("ub", 2.6)]
292