• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright Jim Bosch & Ankit Daftery 2010-2012.
2 // Copyright Stefan Seefeld 2016.
3 // Distributed under the Boost Software License, Version 1.0.
4 // (See accompanying file LICENSE_1_0.txt or copy at
5 // http://www.boost.org/LICENSE_1_0.txt)
6 
#include <complex>
#include <cstddef>

#include <boost/python/numpy.hpp>
#include <boost/mpl/vector.hpp>
#include <boost/mpl/vector_c.hpp>
10 
11 namespace p = boost::python;
12 namespace np = boost::python::numpy;
13 
14 struct ArrayFiller
15 {
16 
17   typedef boost::mpl::vector< short, int, float, std::complex<double> > TypeSequence;
18   typedef boost::mpl::vector_c< int, 1, 2 > DimSequence;
19 
ArrayFillerArrayFiller20   explicit ArrayFiller(np::ndarray const & arg) : argument(arg) {}
21 
22   template <typename T, int N>
applyArrayFiller23   void apply() const
24   {
25     if (N == 1)
26     {
27       char * p = argument.get_data();
28       int stride = argument.strides(0);
29       int size = argument.shape(0);
30       for (int n = 0; n != size; ++n, p += stride)
31 	*reinterpret_cast<T*>(p) = static_cast<T>(n);
32     }
33     else
34     {
35       char * row_p = argument.get_data();
36       int row_stride = argument.strides(0);
37       int col_stride = argument.strides(1);
38       int rows = argument.shape(0);
39       int cols = argument.shape(1);
40       int i = 0;
41       for (int n = 0; n != rows; ++n, row_p += row_stride)
42       {
43 	char * col_p = row_p;
44 	for (int m = 0; m != cols; ++i, ++m, col_p += col_stride)
45 	  *reinterpret_cast<T*>(col_p) = static_cast<T>(i);
46       }
47     }
48   }
49 
50   np::ndarray argument;
51 };
52 
fill(np::ndarray const & arg)53 void fill(np::ndarray const & arg)
54 {
55   ArrayFiller filler(arg);
56   np::invoke_matching_array<ArrayFiller::TypeSequence, ArrayFiller::DimSequence >(arg, filler);
57 }
58 
// Defines the Python extension module `templates_ext`, which exposes a
// single function, `fill`, wrapping the C++ fill() above.
BOOST_PYTHON_MODULE(templates_ext)
{
  // Set up Boost.Python's NumPy support before registering any function
  // that accepts an ndarray.
  np::initialize();

  p::def("fill", &fill);
}
64