• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2008 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_VISITOR_H
11 #define EIGEN_VISITOR_H
12 
13 namespace Eigen {
14 
15 namespace internal {
16 
17 template<typename Visitor, typename Derived, int UnrollCount>
18 struct visitor_impl
19 {
20   enum {
21     col = (UnrollCount-1) / Derived::RowsAtCompileTime,
22     row = (UnrollCount-1) % Derived::RowsAtCompileTime
23   };
24 
25   EIGEN_DEVICE_FUNC
runvisitor_impl26   static inline void run(const Derived &mat, Visitor& visitor)
27   {
28     visitor_impl<Visitor, Derived, UnrollCount-1>::run(mat, visitor);
29     visitor(mat.coeff(row, col), row, col);
30   }
31 };
32 
33 template<typename Visitor, typename Derived>
34 struct visitor_impl<Visitor, Derived, 1>
35 {
36   EIGEN_DEVICE_FUNC
37   static inline void run(const Derived &mat, Visitor& visitor)
38   {
39     return visitor.init(mat.coeff(0, 0), 0, 0);
40   }
41 };
42 
43 template<typename Visitor, typename Derived>
44 struct visitor_impl<Visitor, Derived, Dynamic>
45 {
46   EIGEN_DEVICE_FUNC
47   static inline void run(const Derived& mat, Visitor& visitor)
48   {
49     visitor.init(mat.coeff(0,0), 0, 0);
50     for(Index i = 1; i < mat.rows(); ++i)
51       visitor(mat.coeff(i, 0), i, 0);
52     for(Index j = 1; j < mat.cols(); ++j)
53       for(Index i = 0; i < mat.rows(); ++i)
54         visitor(mat.coeff(i, j), i, j);
55   }
56 };
57 
58 // evaluator adaptor
59 template<typename XprType>
60 class visitor_evaluator
61 {
62 public:
63   EIGEN_DEVICE_FUNC
64   explicit visitor_evaluator(const XprType &xpr) : m_evaluator(xpr), m_xpr(xpr) {}
65 
66   typedef typename XprType::Scalar Scalar;
67   typedef typename XprType::CoeffReturnType CoeffReturnType;
68 
69   enum {
70     RowsAtCompileTime = XprType::RowsAtCompileTime,
71     CoeffReadCost = internal::evaluator<XprType>::CoeffReadCost
72   };
73 
74   EIGEN_DEVICE_FUNC Index rows() const { return m_xpr.rows(); }
75   EIGEN_DEVICE_FUNC Index cols() const { return m_xpr.cols(); }
76   EIGEN_DEVICE_FUNC Index size() const { return m_xpr.size(); }
77 
78   EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index row, Index col) const
79   { return m_evaluator.coeff(row, col); }
80 
81 protected:
82   internal::evaluator<XprType> m_evaluator;
83   const XprType &m_xpr;
84 };
85 } // end namespace internal
86 
87 /** Applies the visitor \a visitor to the whole coefficients of the matrix or vector.
88   *
89   * The template parameter \a Visitor is the type of the visitor and provides the following interface:
90   * \code
91   * struct MyVisitor {
92   *   // called for the first coefficient
93   *   void init(const Scalar& value, Index i, Index j);
94   *   // called for all other coefficients
95   *   void operator() (const Scalar& value, Index i, Index j);
96   * };
97   * \endcode
98   *
99   * \note compared to one or two \em for \em loops, visitors offer automatic
100   * unrolling for small fixed size matrix.
101   *
102   * \sa minCoeff(Index*,Index*), maxCoeff(Index*,Index*), DenseBase::redux()
103   */
104 template<typename Derived>
105 template<typename Visitor>
106 EIGEN_DEVICE_FUNC
107 void DenseBase<Derived>::visit(Visitor& visitor) const
108 {
109   typedef typename internal::visitor_evaluator<Derived> ThisEvaluator;
110   ThisEvaluator thisEval(derived());
111 
112   enum {
113     unroll =  SizeAtCompileTime != Dynamic
114            && SizeAtCompileTime * ThisEvaluator::CoeffReadCost + (SizeAtCompileTime-1) * internal::functor_traits<Visitor>::Cost <= EIGEN_UNROLLING_LIMIT
115   };
116   return internal::visitor_impl<Visitor, ThisEvaluator, unroll ? int(SizeAtCompileTime) : Dynamic>::run(thisEval, visitor);
117 }
118 
119 namespace internal {
120 
121 /** \internal
122   * \brief Base class to implement min and max visitors
123   */
124 template <typename Derived>
125 struct coeff_visitor
126 {
127   typedef typename Derived::Scalar Scalar;
128   Index row, col;
129   Scalar res;
130   EIGEN_DEVICE_FUNC
131   inline void init(const Scalar& value, Index i, Index j)
132   {
133     res = value;
134     row = i;
135     col = j;
136   }
137 };
138 
139 /** \internal
140   * \brief Visitor computing the min coefficient with its value and coordinates
141   *
142   * \sa DenseBase::minCoeff(Index*, Index*)
143   */
144 template <typename Derived>
145 struct min_coeff_visitor : coeff_visitor<Derived>
146 {
147   typedef typename Derived::Scalar Scalar;
148   EIGEN_DEVICE_FUNC
149   void operator() (const Scalar& value, Index i, Index j)
150   {
151     if(value < this->res)
152     {
153       this->res = value;
154       this->row = i;
155       this->col = j;
156     }
157   }
158 };
159 
160 template<typename Scalar>
161 struct functor_traits<min_coeff_visitor<Scalar> > {
162   enum {
163     Cost = NumTraits<Scalar>::AddCost
164   };
165 };
166 
167 /** \internal
168   * \brief Visitor computing the max coefficient with its value and coordinates
169   *
170   * \sa DenseBase::maxCoeff(Index*, Index*)
171   */
172 template <typename Derived>
173 struct max_coeff_visitor : coeff_visitor<Derived>
174 {
175   typedef typename Derived::Scalar Scalar;
176   EIGEN_DEVICE_FUNC
177   void operator() (const Scalar& value, Index i, Index j)
178   {
179     if(value > this->res)
180     {
181       this->res = value;
182       this->row = i;
183       this->col = j;
184     }
185   }
186 };
187 
188 template<typename Scalar>
189 struct functor_traits<max_coeff_visitor<Scalar> > {
190   enum {
191     Cost = NumTraits<Scalar>::AddCost
192   };
193 };
194 
195 } // end namespace internal
196 
197 /** \fn DenseBase<Derived>::minCoeff(IndexType* rowId, IndexType* colId) const
198   * \returns the minimum of all coefficients of *this and puts in *row and *col its location.
199   * \warning the result is undefined if \c *this contains NaN.
200   *
201   * \sa DenseBase::minCoeff(Index*), DenseBase::maxCoeff(Index*,Index*), DenseBase::visit(), DenseBase::minCoeff()
202   */
203 template<typename Derived>
204 template<typename IndexType>
205 EIGEN_DEVICE_FUNC
206 typename internal::traits<Derived>::Scalar
207 DenseBase<Derived>::minCoeff(IndexType* rowId, IndexType* colId) const
208 {
209   internal::min_coeff_visitor<Derived> minVisitor;
210   this->visit(minVisitor);
211   *rowId = minVisitor.row;
212   if (colId) *colId = minVisitor.col;
213   return minVisitor.res;
214 }
215 
216 /** \returns the minimum of all coefficients of *this and puts in *index its location.
217   * \warning the result is undefined if \c *this contains NaN.
218   *
219   * \sa DenseBase::minCoeff(IndexType*,IndexType*), DenseBase::maxCoeff(IndexType*,IndexType*), DenseBase::visit(), DenseBase::minCoeff()
220   */
221 template<typename Derived>
222 template<typename IndexType>
223 EIGEN_DEVICE_FUNC
224 typename internal::traits<Derived>::Scalar
225 DenseBase<Derived>::minCoeff(IndexType* index) const
226 {
227   EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
228   internal::min_coeff_visitor<Derived> minVisitor;
229   this->visit(minVisitor);
230   *index = IndexType((RowsAtCompileTime==1) ? minVisitor.col : minVisitor.row);
231   return minVisitor.res;
232 }
233 
234 /** \fn DenseBase<Derived>::maxCoeff(IndexType* rowId, IndexType* colId) const
235   * \returns the maximum of all coefficients of *this and puts in *row and *col its location.
236   * \warning the result is undefined if \c *this contains NaN.
237   *
238   * \sa DenseBase::minCoeff(IndexType*,IndexType*), DenseBase::visit(), DenseBase::maxCoeff()
239   */
240 template<typename Derived>
241 template<typename IndexType>
242 EIGEN_DEVICE_FUNC
243 typename internal::traits<Derived>::Scalar
244 DenseBase<Derived>::maxCoeff(IndexType* rowPtr, IndexType* colPtr) const
245 {
246   internal::max_coeff_visitor<Derived> maxVisitor;
247   this->visit(maxVisitor);
248   *rowPtr = maxVisitor.row;
249   if (colPtr) *colPtr = maxVisitor.col;
250   return maxVisitor.res;
251 }
252 
253 /** \returns the maximum of all coefficients of *this and puts in *index its location.
254   * \warning the result is undefined if \c *this contains NaN.
255   *
256   * \sa DenseBase::maxCoeff(IndexType*,IndexType*), DenseBase::minCoeff(IndexType*,IndexType*), DenseBase::visitor(), DenseBase::maxCoeff()
257   */
258 template<typename Derived>
259 template<typename IndexType>
260 EIGEN_DEVICE_FUNC
261 typename internal::traits<Derived>::Scalar
262 DenseBase<Derived>::maxCoeff(IndexType* index) const
263 {
264   EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
265   internal::max_coeff_visitor<Derived> maxVisitor;
266   this->visit(maxVisitor);
267   *index = (RowsAtCompileTime==1) ? maxVisitor.col : maxVisitor.row;
268   return maxVisitor.res;
269 }
270 
271 } // end namespace Eigen
272 
273 #endif // EIGEN_VISITOR_H
274