• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
    tests/test_numpy_vectorize.cpp -- auto-vectorize functions over NumPy array
    arguments

    Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>

    All rights reserved. Use of this source code is governed by a
    BSD-style license that can be found in the LICENSE file.
*/

11 #include "pybind11_tests.h"
12 #include <pybind11/numpy.h>
13 
my_func(int x,float y,double z)14 double my_func(int x, float y, double z) {
15     py::print("my_func(x:int={}, y:float={:.0f}, z:float={:.0f})"_s.format(x, y, z));
16     return (float) x*y*z;
17 }
18 
TEST_SUBMODULE(numpy_vectorize,m)19 TEST_SUBMODULE(numpy_vectorize, m) {
20     try { py::module_::import("numpy"); }
21     catch (...) { return; }
22 
23     // test_vectorize, test_docs, test_array_collapse
24     // Vectorize all arguments of a function (though non-vector arguments are also allowed)
25     m.def("vectorized_func", py::vectorize(my_func));
26 
27     // Vectorize a lambda function with a capture object (e.g. to exclude some arguments from the vectorization)
28     m.def("vectorized_func2",
29         [](py::array_t<int> x, py::array_t<float> y, float z) {
30             return py::vectorize([z](int x, float y) { return my_func(x, y, z); })(x, y);
31         }
32     );
33 
34     // Vectorize a complex-valued function
35     m.def("vectorized_func3", py::vectorize(
36         [](std::complex<double> c) { return c * std::complex<double>(2.f); }
37     ));
38 
39     // test_type_selection
40     // NumPy function which only accepts specific data types
41     m.def("selective_func", [](py::array_t<int, py::array::c_style>) { return "Int branch taken."; });
42     m.def("selective_func", [](py::array_t<float, py::array::c_style>) { return "Float branch taken."; });
43     m.def("selective_func", [](py::array_t<std::complex<float>, py::array::c_style>) { return "Complex float branch taken."; });
44 
45 
46     // test_passthrough_arguments
47     // Passthrough test: references and non-pod types should be automatically passed through (in the
48     // function definition below, only `b`, `d`, and `g` are vectorized):
49     struct NonPODClass {
50         NonPODClass(int v) : value{v} {}
51         int value;
52     };
53     py::class_<NonPODClass>(m, "NonPODClass")
54         .def(py::init<int>())
55         .def_readwrite("value", &NonPODClass::value);
56     m.def("vec_passthrough", py::vectorize(
57         [](double *a, double b, py::array_t<double> c, const int &d, int &e, NonPODClass f, const double g) {
58             return *a + b + c.at(0) + d + e + f.value + g;
59         }
60     ));
61 
62     // test_method_vectorization
63     struct VectorizeTestClass {
64         VectorizeTestClass(int v) : value{v} {};
65         float method(int x, float y) { return y + (float) (x + value); }
66         int value = 0;
67     };
68     py::class_<VectorizeTestClass> vtc(m, "VectorizeTestClass");
69     vtc .def(py::init<int>())
70         .def_readwrite("value", &VectorizeTestClass::value);
71 
72     // Automatic vectorizing of methods
73     vtc.def("method", py::vectorize(&VectorizeTestClass::method));
74 
75     // test_trivial_broadcasting
76     // Internal optimization test for whether the input is trivially broadcastable:
77     py::enum_<py::detail::broadcast_trivial>(m, "trivial")
78         .value("f_trivial", py::detail::broadcast_trivial::f_trivial)
79         .value("c_trivial", py::detail::broadcast_trivial::c_trivial)
80         .value("non_trivial", py::detail::broadcast_trivial::non_trivial);
81     m.def("vectorized_is_trivial", [](
82                 py::array_t<int, py::array::forcecast> arg1,
83                 py::array_t<float, py::array::forcecast> arg2,
84                 py::array_t<double, py::array::forcecast> arg3
85                 ) {
86         py::ssize_t ndim;
87         std::vector<py::ssize_t> shape;
88         std::array<py::buffer_info, 3> buffers {{ arg1.request(), arg2.request(), arg3.request() }};
89         return py::detail::broadcast(buffers, ndim, shape);
90     });
91 
92     m.def("add_to", py::vectorize([](NonPODClass& x, int a) { x.value += a; }));
93 }
94