• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# -*- coding: utf-8 -*-
2import pytest
3
4import env  # noqa: F401
5
6from pybind11_tests import numpy_array as m
7
8np = pytest.importorskip("numpy")
9
10
def test_dtypes():
    """Check pybind11's dtypes against numpy's (see issue #1328)."""
    # Platform-dependent sizes: C++ and numpy must agree on every itemsize.
    for platform_check in m.get_platform_dtype_size_checks():
        print(platform_check)
        assert platform_check.size_cpp == platform_check.size_numpy, platform_check

    # Concrete sizes: the dtypes must compare equal; a differing type number is
    # only reported, not treated as a failure.
    for dtype_check in m.get_concrete_dtype_checks():
        print(dtype_check)
        np_dtype, pb_dtype = dtype_check.numpy, dtype_check.pybind11
        assert np_dtype == pb_dtype, dtype_check
        if np_dtype.num == pb_dtype.num:
            continue
        note = "NOTE: typenum mismatch for {}: {} != {}".format(
            dtype_check, np_dtype.num, pb_dtype.num
        )
        print(note)
27
28
@pytest.fixture(scope="function")
def arr():
    """Fresh 2x3 array of native-byte-order uint16 values for each test."""
    rows = [[1, 2, 3], [4, 5, 6]]
    return np.array(rows, dtype="=u2")
32
33
def test_array_attributes():
    """Check array introspection helpers on a 0-d array and a read-only 2-d view."""

    def check_invalid_axis(array, axis, message):
        # m.shape and m.strides must both reject an out-of-range axis.
        for accessor in (m.shape, m.strides):
            with pytest.raises(IndexError) as excinfo:
                accessor(array, axis)
            assert str(excinfo.value) == message

    # 0-d float64 array: no axes, a single 8-byte element, owns its buffer.
    scalar = np.array(0, "f8")
    assert m.ndim(scalar) == 0
    assert all(m.shape(scalar) == [])
    assert all(m.strides(scalar) == [])
    check_invalid_axis(scalar, 0, "invalid axis: 0 (ndim = 0)")
    assert m.writeable(scalar)
    assert m.size(scalar) == 1
    assert m.itemsize(scalar) == 8
    assert m.nbytes(scalar) == 8
    assert m.owndata(scalar)

    # Read-only 2x3 uint16 view: byte strides (6, 2), and the view does not
    # own its data.
    view = np.array([[1, 2, 3], [4, 5, 6]], "u2").view()
    view.flags.writeable = False
    assert m.ndim(view) == 2
    assert all(m.shape(view) == [2, 3])
    assert [m.shape(view, axis) for axis in range(2)] == [2, 3]
    assert all(m.strides(view) == [6, 2])
    assert [m.strides(view, axis) for axis in range(2)] == [6, 2]
    check_invalid_axis(view, 2, "invalid axis: 2 (ndim = 2)")
    assert not m.writeable(view)
    assert m.size(view) == 6
    assert m.itemsize(view) == 2
    assert m.nbytes(view) == 12
    assert not m.owndata(view)
71
72
@pytest.mark.parametrize(
    "args, ret",
    [
        ([], 0),
        ([0], 0),
        ([1], 3),
        ([0, 1], 1),
        ([1, 2], 5),
    ],
)
def test_index_offset(arr, args, ret):
    """Flat element index and byte offset for partial and complete indices."""
    byte_offset = ret * arr.dtype.itemsize
    for index_fn in (m.index_at, m.index_at_t):
        assert index_fn(arr, *args) == ret
    for offset_fn in (m.offset_at, m.offset_at_t):
        assert offset_fn(arr, *args) == byte_offset
81
82
def test_dim_check_fail(arr):
    """Indexed accessors reject more indices than the array has dimensions."""
    accessors = (m.index_at, m.index_at_t, m.offset_at, m.offset_at_t)
    accessors += (m.data, m.data_t, m.mutate_data, m.mutate_data_t)
    expected = "too many indices for an array: 3 (ndim = 2)"
    for accessor in accessors:
        with pytest.raises(IndexError) as excinfo:
            accessor(arr, 1, 2, 3)
        assert str(excinfo.value) == expected
97
98
@pytest.mark.parametrize(
    "args, ret",
    [
        ([], [1, 2, 3, 4, 5, 6]),
        ([1], [4, 5, 6]),
        ([0, 1], [2, 3, 4, 5, 6]),
        ([1, 2], [6]),
    ],
)
def test_data(arr, args, ret):
    """data_t yields the uint16 elements; data is checked as their raw bytes."""
    import sys

    # Every test value fits in one byte, so in memory each 2-byte element is
    # (value, 0) on little-endian hosts and (0, value) on big-endian ones.
    value_byte, zero_byte = (0, 1) if sys.byteorder == "little" else (1, 0)

    assert all(m.data_t(arr, *args) == ret)
    assert all(m.data(arr, *args)[value_byte::2] == ret)
    assert all(m.data(arr, *args)[zero_byte::2] == 0)
114
115
@pytest.mark.parametrize("dim", [0, 1, 3])
def test_at_fail(arr, dim):
    """at_t / mutate_at_t reject an index count different from ndim (2)."""
    indices = [0] * dim
    expected = "index dimension mismatch: {} (ndim = 2)".format(dim)
    for accessor in (m.at_t, m.mutate_at_t):
        with pytest.raises(IndexError) as excinfo:
            accessor(arr, *indices)
        assert str(excinfo.value) == expected
124
125
def test_at(arr):
    """Read single elements, then increment single elements in place."""
    for index, value in [((0, 2), 3), ((1, 0), 4)]:
        assert m.at_t(arr, *index) == value

    # Each increment acts on the already-mutated array, so the expected
    # contents accumulate from one step to the next.
    increments = [
        ((0, 2), [1, 2, 4, 4, 5, 6]),
        ((1, 0), [1, 2, 4, 5, 5, 6]),
    ]
    for index, flat in increments:
        assert all(m.mutate_at_t(arr, *index).ravel() == flat)
132
133
def test_mutate_readonly(arr):
    """Every mutating accessor refuses a read-only array."""
    arr.flags.writeable = False
    mutators = (m.mutate_data, m.mutate_data_t, m.mutate_at_t)
    trailing_args = ((), (), (0, 0))
    for mutator, extra in zip(mutators, trailing_args):
        with pytest.raises(ValueError) as excinfo:
            mutator(arr, *extra)
        assert str(excinfo.value) == "array is not writeable"
144
145
def test_mutate_data(arr):
    """mutate_data doubles, mutate_data_t increments, from the given index on.

    Elements are counted in flat C order starting at the (possibly partial)
    index. Every step mutates the same array, so expectations accumulate.
    """
    doubling_steps = [
        ((), [2, 4, 6, 8, 10, 12]),
        ((), [4, 8, 12, 16, 20, 24]),
        ((1,), [4, 8, 12, 32, 40, 48]),
        ((0, 1), [4, 16, 24, 64, 80, 96]),
        ((1, 2), [4, 16, 24, 64, 80, 192]),
    ]
    for index, expected in doubling_steps:
        assert all(m.mutate_data(arr, *index).ravel() == expected)

    increment_steps = [
        ((), [5, 17, 25, 65, 81, 193]),
        ((), [6, 18, 26, 66, 82, 194]),
        ((1,), [6, 18, 26, 67, 83, 195]),
        ((0, 1), [6, 19, 27, 68, 84, 196]),
        ((1, 2), [6, 19, 27, 68, 84, 197]),
    ]
    for index, expected in increment_steps:
        assert all(m.mutate_data_t(arr, *index).ravel() == expected)
158
159
def test_bounds_check(arr):
    """Out-of-range indices raise IndexError naming the axis and its size."""
    checked = (
        m.index_at,
        m.index_at_t,
        m.data,
        m.data_t,
        m.mutate_data,
        m.mutate_data_t,
        m.at_t,
        m.mutate_at_t,
    )
    bad_indices = [
        ((2, 0), "index 2 is out of bounds for axis 0 with size 2"),
        ((0, 4), "index 4 is out of bounds for axis 1 with size 3"),
    ]
    for func in checked:
        for index, message in bad_indices:
            with pytest.raises(IndexError) as excinfo:
                func(arr, *index)
            assert str(excinfo.value) == message
177
178
def test_make_c_f_array():
    """Each factory's array is contiguous in its own order but not the other."""
    layouts = (
        (m.make_c_array, "c_contiguous", "f_contiguous"),
        (m.make_f_array, "f_contiguous", "c_contiguous"),
    )
    for factory, own_flag, other_flag in layouts:
        # A freshly built array is checked for each flag.
        assert getattr(factory().flags, own_flag)
        assert not getattr(factory().flags, other_flag)
184
185
def test_make_empty_shaped_array():
    """An empty shape is accepted and describes a 0-d, scalar-like array."""
    m.make_empty_shaped_array()

    # PEP 3118: an empty shape means a numpy scalar (zero dimensions).
    assert (m.scalar_int().ndim, m.scalar_int().shape) == (0, ())
    assert m.scalar_int() == 42
193
194
def test_wrap():
    """m.wrap(x) must return a distinct array object viewing x's memory."""

    def assert_references(a, b, base=None):
        """Assert that ``b`` is a non-owning alias of the data of ``a``.

        ``base`` is the object expected as ``b.base``; it defaults to ``a``.
        """
        if base is None:
            base = a
        assert a is not b
        assert a.__array_interface__["data"][0] == b.__array_interface__["data"][0]
        assert a.shape == b.shape
        assert a.strides == b.strides
        assert a.flags.c_contiguous == b.flags.c_contiguous
        assert a.flags.f_contiguous == b.flags.f_contiguous
        assert a.flags.writeable == b.flags.writeable
        assert a.flags.aligned == b.flags.aligned
        # The copy-back flag is exposed as ``writebackifcopy`` from NumPy 1.14 on
        # and as ``updateifcopy`` before that. Compare major/minor numerically
        # instead of via distutils.version.LooseVersion: distutils is deprecated
        # (PEP 632) and does not ship with Python 3.12+.
        numpy_major_minor = tuple(int(x) for x in np.__version__.split(".")[:2])
        if numpy_major_minor >= (1, 14):
            assert a.flags.writebackifcopy == b.flags.writebackifcopy
        else:
            assert a.flags.updateifcopy == b.flags.updateifcopy
        assert np.all(a == b)
        assert not b.flags.owndata
        assert b.base is base
        # For writeable 2-d inputs, a write through `a` must be visible in `b`.
        if a.flags.writeable and a.ndim == 2:
            a[0, 0] = 1234
            assert b[0, 0] == 1234

    # 1-d array that owns its data.
    a1 = np.array([1, 2], dtype=np.int16)
    assert a1.flags.owndata and a1.base is None
    a2 = m.wrap(a1)
    assert_references(a1, a2)

    # Fortran-ordered 2-d array that owns its data.
    a1 = np.array([[1, 2], [3, 4]], dtype=np.float32, order="F")
    assert a1.flags.owndata and a1.base is None
    a2 = m.wrap(a1)
    assert_references(a1, a2)

    # Read-only C-ordered array: the wrapper must stay read-only too.
    a1 = np.array([[1, 2], [3, 4]], dtype=np.float32, order="C")
    a1.flags.writeable = False
    a2 = m.wrap(a1)
    assert_references(a1, a2)

    a1 = np.random.random((4, 4, 4))
    a2 = m.wrap(a1)
    assert_references(a1, a2)

    # Views of a1 (transpose, diagonal, reversed): the wrapper's base must be
    # the original owning array a1.
    a1t = a1.transpose()
    a2 = m.wrap(a1t)
    assert_references(a1t, a2, a1)

    a1d = a1.diagonal()
    a2 = m.wrap(a1d)
    assert_references(a1d, a2, a1)

    a1m = a1[::-1, ::-1, ::-1]
    a2 = m.wrap(a1m)
    assert_references(a1m, a2, a1)
250
251
def test_numpy_view(capture):
    """Views from ArrayClass.numpy_view() share memory and keep the owner alive.

    The captured C++ output shows the ArrayClass destructor does not run when
    the direct Python reference is deleted while views still exist; it runs
    only after both views have been released.
    """
    with capture:
        ac = m.ArrayClass()
        ac_view_1 = ac.numpy_view()
        ac_view_2 = ac.numpy_view()
        assert np.all(ac_view_1 == np.array([1, 2], dtype=np.int32))
        # Drop the only direct reference; the views must keep the object
        # alive, so no "~ArrayClass()" line is expected in this capture.
        # pytest.gc_collect is presumably a helper installed by the suite's
        # conftest (not visible here) — TODO confirm.
        del ac
        pytest.gc_collect()
    assert (
        capture
        == """
        ArrayClass()
        ArrayClass::numpy_view()
        ArrayClass::numpy_view()
    """
    )
    # A write through one view is visible through the other: same buffer.
    ac_view_1[0] = 4
    ac_view_1[1] = 3
    assert ac_view_2[0] == 4
    assert ac_view_2[1] == 3
    # Releasing the last views should finally destroy the C++ object. The
    # collection is run twice, presumably so deferred cleanup completes on
    # every interpreter — TODO confirm.
    with capture:
        del ac_view_1
        del ac_view_2
        pytest.gc_collect()
        pytest.gc_collect()
    assert (
        capture
        == """
        ~ArrayClass()
    """
    )
283
284
def test_cast_numpy_int64_to_uint64():
    """A uint64 parameter accepts both a Python int and a numpy.uint64."""
    for value in (123, np.uint64(123)):
        m.function_taking_uint64(value)
288
289
def test_isinstance():
    """Exercise the C++-side isinstance helpers with numpy arrays."""
    int_array = np.array([1, 2, 3])
    assert m.isinstance_untyped(int_array, "not an array")
    float_array = np.array([1.0, 2.0, 3.0])
    assert m.isinstance_typed(float_array)
293
294
def test_constructors():
    """Default-constructed arrays are empty; converting ones copy list values."""
    defaults = m.default_constructors()
    for empty in defaults.values():
        assert empty.size == 0
    default_dtypes = [
        ("array", np.array([]).dtype),
        ("array_t<int32>", np.int32),
        ("array_t<double>", np.float64),
    ]
    for key, dtype in default_dtypes:
        assert defaults[key].dtype == dtype

    converted = m.converting_constructors([1, 2, 3])
    for values in converted.values():
        np.testing.assert_array_equal(values, [1, 2, 3])
    converted_dtypes = [
        ("array", np.int_),
        ("array_t<int32>", np.int32),
        ("array_t<double>", np.float64),
    ]
    for key, dtype in converted_dtypes:
        assert converted[key].dtype == dtype
309
310
def test_overload_resolution(msg):
    """Array overloads: exact dtype matches win, else the first convertible one."""
    # Exact overload matches:
    assert m.overloaded(np.array([1], dtype="float64")) == "double"
    assert m.overloaded(np.array([1], dtype="float32")) == "float"
    assert m.overloaded(np.array([1], dtype="ushort")) == "unsigned short"
    assert m.overloaded(np.array([1], dtype="intc")) == "int"
    assert m.overloaded(np.array([1], dtype="longlong")) == "long long"
    assert m.overloaded(np.array([1], dtype="complex")) == "double complex"
    assert m.overloaded(np.array([1], dtype="csingle")) == "float complex"

    # No exact match, should call first convertible version:
    assert m.overloaded(np.array([1], dtype="uint8")) == "double"

    with pytest.raises(TypeError) as excinfo:
        m.overloaded("not an array")
    assert (
        msg(excinfo.value)
        == """
        overloaded(): incompatible function arguments. The following argument types are supported:
            1. (arg0: numpy.ndarray[numpy.float64]) -> str
            2. (arg0: numpy.ndarray[numpy.float32]) -> str
            3. (arg0: numpy.ndarray[numpy.int32]) -> str
            4. (arg0: numpy.ndarray[numpy.uint16]) -> str
            5. (arg0: numpy.ndarray[numpy.int64]) -> str
            6. (arg0: numpy.ndarray[numpy.complex128]) -> str
            7. (arg0: numpy.ndarray[numpy.complex64]) -> str

        Invoked with: 'not an array'
    """
    )

    assert m.overloaded2(np.array([1], dtype="float64")) == "double"
    assert m.overloaded2(np.array([1], dtype="float32")) == "float"
    assert m.overloaded2(np.array([1], dtype="complex64")) == "float complex"
    assert m.overloaded2(np.array([1], dtype="complex128")) == "double complex"

    # overloaded3 has no convertible fallback, so anything but an exact int32 or
    # float64 array is rejected with the full overload list.
    assert m.overloaded3(np.array([1], dtype="float64")) == "double"
    assert m.overloaded3(np.array([1], dtype="intc")) == "int"
    expected_exc = """
        overloaded3(): incompatible function arguments. The following argument types are supported:
            1. (arg0: numpy.ndarray[numpy.int32]) -> str
            2. (arg0: numpy.ndarray[numpy.float64]) -> str

        Invoked with: """

    with pytest.raises(TypeError) as excinfo:
        m.overloaded3(np.array([1], dtype="uintc"))
    assert msg(excinfo.value) == expected_exc + repr(np.array([1], dtype="uint32"))
    with pytest.raises(TypeError) as excinfo:
        m.overloaded3(np.array([1], dtype="float32"))
    assert msg(excinfo.value) == expected_exc + repr(np.array([1.0], dtype="float32"))
    with pytest.raises(TypeError) as excinfo:
        m.overloaded3(np.array([1], dtype="complex"))
    assert msg(excinfo.value) == expected_exc + repr(np.array([1.0 + 0.0j]))

    # Exact matches:
    assert m.overloaded4(np.array([1], dtype="double")) == "double"
    assert m.overloaded4(np.array([1], dtype="longlong")) == "long long"
    # Non-exact matches requiring conversion.  Since float to integer isn't a
    # safe conversion, it should go to the double overload, but short can go to
    # either (and so should end up on the first-registered, the long long).
    assert m.overloaded4(np.array([1], dtype="float32")) == "double"
    assert m.overloaded4(np.array([1], dtype="short")) == "long long"

    assert m.overloaded5(np.array([1], dtype="double")) == "double"
    assert m.overloaded5(np.array([1], dtype="uintc")) == "unsigned int"
    assert m.overloaded5(np.array([1], dtype="float32")) == "unsigned int"
379
380
def test_greedy_string_overload():
    """Regression test for #685: an ndarray must not bind to std::string."""
    cases = [
        ("abc", "string"),
        (np.array([97, 98, 99], dtype="b"), "array"),
        (123, "other"),
    ]
    for argument, expected in cases:
        assert m.issue685(argument) == expected
387
388
def test_array_unchecked_fixed_dims(msg):
    """Unchecked proxies whose dimension count is fixed at compile time."""
    grid = np.array([[1, 2], [3, 4]], dtype="float64")
    m.proxy_add2(grid, 10)
    assert np.all(grid == [[11, 12], [13, 14]])

    # A 1-d array is refused by the 2-d proxy.
    with pytest.raises(ValueError) as excinfo:
        m.proxy_add2(np.array([1.0, 2, 3]), 5.0)
    assert (
        msg(excinfo.value) == "array has incorrect number of dimensions: 1; expected 2"
    )

    c_order = np.ndarray(shape=(3, 3, 3), buffer=np.array(range(3, 30)), dtype="int")
    assert np.all(m.proxy_init3(3.0) == c_order)
    assert np.all(m.proxy_init3F(3.0) == np.transpose(c_order))

    for values in (np.array(range(6)), np.array(range(6), dtype="float64")):
        assert m.proxy_squared_L2_norm(values) == 55

    # `grid` now holds the values after the +10 above.
    auxiliaries = m.proxy_auxiliaries2(grid)
    assert auxiliaries == [11, 11, True, 2, 8, 2, 2, 4, 32]
    assert m.proxy_auxiliaries2(grid) == m.array_auxiliaries2(grid)

    assert m.proxy_auxiliaries1_const_ref(grid[0, :])
    assert m.proxy_auxiliaries2_const_ref(grid)
413
414
def test_array_unchecked_dyn_dims(msg):
    """Unchecked proxies whose dimension count is given at runtime."""
    grid = np.array([[1, 2], [3, 4]], dtype="float64")
    m.proxy_add2_dyn(grid, 10)
    assert np.all(grid == [[11, 12], [13, 14]])

    c_order = np.ndarray(shape=(3, 3, 3), buffer=np.array(range(3, 30)), dtype="int")
    assert np.all(m.proxy_init3_dyn(3.0) == c_order)

    # `grid` now holds the values after the +10 above.
    dyn_auxiliaries = m.proxy_auxiliaries2_dyn(grid)
    assert dyn_auxiliaries == [11, 11, True, 2, 8, 2, 2, 4, 32]
    assert m.proxy_auxiliaries2_dyn(grid) == m.array_auxiliaries2(grid)
425
426
def test_array_failure():
    """Invalid array construction raises ValueError with a specific message."""
    failures = [
        (m.array_fail_test, "cannot create a pybind11::array from a nullptr"),
        (m.array_t_fail_test, "cannot create a pybind11::array_t from a nullptr"),
        (m.array_fail_test_negative_size, "negative dimensions are not allowed"),
    ]
    for make_array, message in failures:
        with pytest.raises(ValueError) as excinfo:
            make_array()
        assert str(excinfo.value) == message
439
440
def test_initializer_list():
    """array_initializer_listN must produce an array of shape (1, ..., N)."""
    factories = (
        m.array_initializer_list1,
        m.array_initializer_list2,
        m.array_initializer_list3,
        m.array_initializer_list4,
    )
    for rank, factory in enumerate(factories, 1):
        assert factory().shape == tuple(range(1, rank + 1))
446
447
def test_array_resize(msg):
    """In-place reshape and resize of numpy arrays from C++.

    Both failing resizes are asserted with ``pytest.raises``. A bare
    ``try``/``except`` would let the test pass silently if no error were
    raised at all.
    """
    a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype="float64")
    m.array_reshape2(a)
    assert a.size == 9
    assert np.all(a == [[1, 2, 3], [4, 5, 6], [7, 8, 9]])

    # total size change should succeed with refcheck off
    m.array_resize3(a, 4, False)
    assert a.size == 64
    # ... and fail with refcheck on
    with pytest.raises(ValueError) as excinfo:
        m.array_resize3(a, 3, True)
    assert str(excinfo.value).startswith("cannot resize an array")
    # transposed array doesn't own data
    b = a.transpose()
    with pytest.raises(ValueError) as excinfo:
        m.array_resize3(b, 3, False)
    assert str(excinfo.value).startswith(
        "cannot resize this array: it does not own its data"
    )
    # ... but reshape should be fine
    m.array_reshape2(b)
    assert b.shape == (8, 8)
471
472
@pytest.mark.xfail("env.PYPY")
def test_array_create_and_resize(msg):
    """create_and_resize(2) must return 4 elements, all equal to 42.0."""
    created = m.create_and_resize(2)
    assert created.size == 4
    all_42 = np.all(created == 42.0)
    assert all_42
478
479
def test_index_using_ellipsis():
    """index_using_ellipsis maps a (5, 6, 7) array to a 1-d result of length 6."""
    zeros = np.zeros((5, 6, 7))
    result = m.index_using_ellipsis(zeros)
    assert result.shape == (6,)
483
484
@pytest.mark.parametrize("forcecast", [False, True])
@pytest.mark.parametrize("contiguity", [None, "C", "F"])
@pytest.mark.parametrize("noconvert", [False, True])
@pytest.mark.filterwarnings(
    "ignore:Casting complex values to real discards the imaginary part:"
    # ComplexWarning lives in numpy.exceptions since NumPy 1.25 and NumPy 2.0
    # removed the top-level numpy.ComplexWarning, so pick whichever exists;
    # otherwise pytest cannot resolve the filter's warning category.
    + (
        "numpy.exceptions.ComplexWarning"
        if hasattr(np, "exceptions")
        else "numpy.ComplexWarning"
    )
)
def test_argument_conversions(forcecast, contiguity, noconvert):
    """Check which dtype/order/shape combinations each accept_double* variant takes.

    The variant name is assembled from the parameters, e.g.
    ``accept_double_c_style_forcecast_noconvert``.
    """
    function_name = "accept_double"
    if contiguity == "C":
        function_name += "_c_style"
    elif contiguity == "F":
        function_name += "_f_style"
    if forcecast:
        function_name += "_forcecast"
    if noconvert:
        function_name += "_noconvert"
    function = getattr(m, function_name)

    for dtype in [np.dtype("float32"), np.dtype("float64"), np.dtype("complex128")]:
        for order in ["C", "F"]:
            for shape in [(2, 2), (1, 3, 1, 1), (1, 1, 1), (0,)]:
                if not noconvert:
                    # If noconvert is not passed, only complex128 needs to be truncated and
                    # "cannot be safely obtained". So without `forcecast`, the argument shouldn't
                    # be accepted.
                    should_raise = dtype.name == "complex128" and not forcecast
                else:
                    # If noconvert is passed, only float64 and the matching order is accepted.
                    # If at most one dimension has a size greater than 1, the array is also
                    # trivially contiguous.
                    trivially_contiguous = sum(1 for d in shape if d > 1) <= 1
                    should_raise = dtype.name != "float64" or (
                        contiguity is not None
                        and contiguity != order
                        and not trivially_contiguous
                    )

                array = np.zeros(shape, dtype=dtype, order=order)
                if not should_raise:
                    function(array)
                else:
                    with pytest.raises(
                        TypeError, match="incompatible function arguments"
                    ):
                        function(array)
530
531
@pytest.mark.xfail("env.PYPY")
def test_dtype_refcount_leak():
    """Calling m.ndim on an array must not change its dtype's refcount."""
    from sys import getrefcount

    # np.float64 rather than the np.float_ alias, which NumPy 2.0 removed;
    # both name the same dtype.
    dtype = np.dtype(np.float64)
    a = np.array([1], dtype=dtype)
    before = getrefcount(dtype)
    m.ndim(a)
    after = getrefcount(dtype)
    assert after == before
542