1 #include <Eigen/Core>
2 #include <iostream>
3
4 using namespace Eigen;
5
6 // [functor]
7 template<class ArgType, class RowIndexType, class ColIndexType>
8 class indexing_functor {
9 const ArgType &m_arg;
10 const RowIndexType &m_rowIndices;
11 const ColIndexType &m_colIndices;
12 public:
13 typedef Matrix<typename ArgType::Scalar,
14 RowIndexType::SizeAtCompileTime,
15 ColIndexType::SizeAtCompileTime,
16 ArgType::Flags&RowMajorBit?RowMajor:ColMajor,
17 RowIndexType::MaxSizeAtCompileTime,
18 ColIndexType::MaxSizeAtCompileTime> MatrixType;
19
indexing_functor(const ArgType & arg,const RowIndexType & row_indices,const ColIndexType & col_indices)20 indexing_functor(const ArgType& arg, const RowIndexType& row_indices, const ColIndexType& col_indices)
21 : m_arg(arg), m_rowIndices(row_indices), m_colIndices(col_indices)
22 {}
23
operator ()(Index row,Index col) const24 const typename ArgType::Scalar& operator() (Index row, Index col) const {
25 return m_arg(m_rowIndices[row], m_colIndices[col]);
26 }
27 };
28 // [functor]
29
30 // [function]
31 template <class ArgType, class RowIndexType, class ColIndexType>
32 CwiseNullaryOp<indexing_functor<ArgType,RowIndexType,ColIndexType>, typename indexing_functor<ArgType,RowIndexType,ColIndexType>::MatrixType>
indexing(const Eigen::MatrixBase<ArgType> & arg,const RowIndexType & row_indices,const ColIndexType & col_indices)33 indexing(const Eigen::MatrixBase<ArgType>& arg, const RowIndexType& row_indices, const ColIndexType& col_indices)
34 {
35 typedef indexing_functor<ArgType,RowIndexType,ColIndexType> Func;
36 typedef typename Func::MatrixType MatrixType;
37 return MatrixType::NullaryExpr(row_indices.size(), col_indices.size(), Func(arg.derived(), row_indices, col_indices));
38 }
39 // [function]
40
41
main()42 int main()
43 {
44 std::cout << "[main1]\n";
45 Eigen::MatrixXi A = Eigen::MatrixXi::Random(4,4);
46 Array3i ri(1,2,1);
47 ArrayXi ci(6); ci << 3,2,1,0,0,2;
48 Eigen::MatrixXi B = indexing(A, ri, ci);
49 std::cout << "A =" << std::endl;
50 std::cout << A << std::endl << std::endl;
51 std::cout << "A([" << ri.transpose() << "], [" << ci.transpose() << "]) =" << std::endl;
52 std::cout << B << std::endl;
53 std::cout << "[main1]\n";
54
55 std::cout << "[main2]\n";
56 B = indexing(A, ri+1, ci);
57 std::cout << "A(ri+1,ci) =" << std::endl;
58 std::cout << B << std::endl << std::endl;
59 #if __cplusplus >= 201103L
60 B = indexing(A, ArrayXi::LinSpaced(13,0,12).unaryExpr([](int x){return x%4;}), ArrayXi::LinSpaced(4,0,3));
61 std::cout << "A(ArrayXi::LinSpaced(13,0,12).unaryExpr([](int x){return x%4;}), ArrayXi::LinSpaced(4,0,3)) =" << std::endl;
62 std::cout << B << std::endl << std::endl;
63 #endif
64 std::cout << "[main2]\n";
65 }
66
67