Visitor.h 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381
  1. // This file is part of Eigen, a lightweight C++ template library
  2. // for linear algebra.
  3. //
  4. // Copyright (C) 2008 Gael Guennebaud <gael.guennebaud@inria.fr>
  5. //
  6. // This Source Code Form is subject to the terms of the Mozilla
  7. // Public License v. 2.0. If a copy of the MPL was not distributed
  8. // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
  9. #ifndef EIGEN_VISITOR_H
  10. #define EIGEN_VISITOR_H
  11. namespace Eigen {
  12. namespace internal {
  13. template<typename Visitor, typename Derived, int UnrollCount>
  14. struct visitor_impl
  15. {
  16. enum {
  17. col = (UnrollCount-1) / Derived::RowsAtCompileTime,
  18. row = (UnrollCount-1) % Derived::RowsAtCompileTime
  19. };
  20. EIGEN_DEVICE_FUNC
  21. static inline void run(const Derived &mat, Visitor& visitor)
  22. {
  23. visitor_impl<Visitor, Derived, UnrollCount-1>::run(mat, visitor);
  24. visitor(mat.coeff(row, col), row, col);
  25. }
  26. };
  27. template<typename Visitor, typename Derived>
  28. struct visitor_impl<Visitor, Derived, 1>
  29. {
  30. EIGEN_DEVICE_FUNC
  31. static inline void run(const Derived &mat, Visitor& visitor)
  32. {
  33. return visitor.init(mat.coeff(0, 0), 0, 0);
  34. }
  35. };
  36. // This specialization enables visitors on empty matrices at compile-time
  37. template<typename Visitor, typename Derived>
  38. struct visitor_impl<Visitor, Derived, 0> {
  39. EIGEN_DEVICE_FUNC
  40. static inline void run(const Derived &/*mat*/, Visitor& /*visitor*/)
  41. {}
  42. };
  43. template<typename Visitor, typename Derived>
  44. struct visitor_impl<Visitor, Derived, Dynamic>
  45. {
  46. EIGEN_DEVICE_FUNC
  47. static inline void run(const Derived& mat, Visitor& visitor)
  48. {
  49. visitor.init(mat.coeff(0,0), 0, 0);
  50. for(Index i = 1; i < mat.rows(); ++i)
  51. visitor(mat.coeff(i, 0), i, 0);
  52. for(Index j = 1; j < mat.cols(); ++j)
  53. for(Index i = 0; i < mat.rows(); ++i)
  54. visitor(mat.coeff(i, j), i, j);
  55. }
  56. };
  57. // evaluator adaptor
  58. template<typename XprType>
  59. class visitor_evaluator
  60. {
  61. public:
  62. EIGEN_DEVICE_FUNC
  63. explicit visitor_evaluator(const XprType &xpr) : m_evaluator(xpr), m_xpr(xpr) {}
  64. typedef typename XprType::Scalar Scalar;
  65. typedef typename XprType::CoeffReturnType CoeffReturnType;
  66. enum {
  67. RowsAtCompileTime = XprType::RowsAtCompileTime,
  68. CoeffReadCost = internal::evaluator<XprType>::CoeffReadCost
  69. };
  70. EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT { return m_xpr.rows(); }
  71. EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT { return m_xpr.cols(); }
  72. EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index size() const EIGEN_NOEXCEPT { return m_xpr.size(); }
  73. EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index row, Index col) const
  74. { return m_evaluator.coeff(row, col); }
  75. protected:
  76. internal::evaluator<XprType> m_evaluator;
  77. const XprType &m_xpr;
  78. };
  79. } // end namespace internal
  80. /** Applies the visitor \a visitor to the whole coefficients of the matrix or vector.
  81. *
  82. * The template parameter \a Visitor is the type of the visitor and provides the following interface:
  83. * \code
  84. * struct MyVisitor {
  85. * // called for the first coefficient
  86. * void init(const Scalar& value, Index i, Index j);
  87. * // called for all other coefficients
  88. * void operator() (const Scalar& value, Index i, Index j);
  89. * };
  90. * \endcode
  91. *
  92. * \note compared to one or two \em for \em loops, visitors offer automatic
  93. * unrolling for small fixed size matrix.
  94. *
  95. * \note if the matrix is empty, then the visitor is left unchanged.
  96. *
  97. * \sa minCoeff(Index*,Index*), maxCoeff(Index*,Index*), DenseBase::redux()
  98. */
  99. template<typename Derived>
  100. template<typename Visitor>
  101. EIGEN_DEVICE_FUNC
  102. void DenseBase<Derived>::visit(Visitor& visitor) const
  103. {
  104. if(size()==0)
  105. return;
  106. typedef typename internal::visitor_evaluator<Derived> ThisEvaluator;
  107. ThisEvaluator thisEval(derived());
  108. enum {
  109. unroll = SizeAtCompileTime != Dynamic
  110. && SizeAtCompileTime * int(ThisEvaluator::CoeffReadCost) + (SizeAtCompileTime-1) * int(internal::functor_traits<Visitor>::Cost) <= EIGEN_UNROLLING_LIMIT
  111. };
  112. return internal::visitor_impl<Visitor, ThisEvaluator, unroll ? int(SizeAtCompileTime) : Dynamic>::run(thisEval, visitor);
  113. }
  114. namespace internal {
  115. /** \internal
  116. * \brief Base class to implement min and max visitors
  117. */
  118. template <typename Derived>
  119. struct coeff_visitor
  120. {
  121. // default initialization to avoid countless invalid maybe-uninitialized warnings by gcc
  122. EIGEN_DEVICE_FUNC
  123. coeff_visitor() : row(-1), col(-1), res(0) {}
  124. typedef typename Derived::Scalar Scalar;
  125. Index row, col;
  126. Scalar res;
  127. EIGEN_DEVICE_FUNC
  128. inline void init(const Scalar& value, Index i, Index j)
  129. {
  130. res = value;
  131. row = i;
  132. col = j;
  133. }
  134. };
  135. /** \internal
  136. * \brief Visitor computing the min coefficient with its value and coordinates
  137. *
  138. * \sa DenseBase::minCoeff(Index*, Index*)
  139. */
  140. template <typename Derived, int NaNPropagation>
  141. struct min_coeff_visitor : coeff_visitor<Derived>
  142. {
  143. typedef typename Derived::Scalar Scalar;
  144. EIGEN_DEVICE_FUNC
  145. void operator() (const Scalar& value, Index i, Index j)
  146. {
  147. if(value < this->res)
  148. {
  149. this->res = value;
  150. this->row = i;
  151. this->col = j;
  152. }
  153. }
  154. };
  155. template <typename Derived>
  156. struct min_coeff_visitor<Derived, PropagateNumbers> : coeff_visitor<Derived>
  157. {
  158. typedef typename Derived::Scalar Scalar;
  159. EIGEN_DEVICE_FUNC
  160. void operator() (const Scalar& value, Index i, Index j)
  161. {
  162. if((numext::isnan)(this->res) || (!(numext::isnan)(value) && value < this->res))
  163. {
  164. this->res = value;
  165. this->row = i;
  166. this->col = j;
  167. }
  168. }
  169. };
  170. template <typename Derived>
  171. struct min_coeff_visitor<Derived, PropagateNaN> : coeff_visitor<Derived>
  172. {
  173. typedef typename Derived::Scalar Scalar;
  174. EIGEN_DEVICE_FUNC
  175. void operator() (const Scalar& value, Index i, Index j)
  176. {
  177. if((numext::isnan)(value) || value < this->res)
  178. {
  179. this->res = value;
  180. this->row = i;
  181. this->col = j;
  182. }
  183. }
  184. };
  185. template<typename Scalar, int NaNPropagation>
  186. struct functor_traits<min_coeff_visitor<Scalar, NaNPropagation> > {
  187. enum {
  188. Cost = NumTraits<Scalar>::AddCost
  189. };
  190. };
  191. /** \internal
  192. * \brief Visitor computing the max coefficient with its value and coordinates
  193. *
  194. * \sa DenseBase::maxCoeff(Index*, Index*)
  195. */
  196. template <typename Derived, int NaNPropagation>
  197. struct max_coeff_visitor : coeff_visitor<Derived>
  198. {
  199. typedef typename Derived::Scalar Scalar;
  200. EIGEN_DEVICE_FUNC
  201. void operator() (const Scalar& value, Index i, Index j)
  202. {
  203. if(value > this->res)
  204. {
  205. this->res = value;
  206. this->row = i;
  207. this->col = j;
  208. }
  209. }
  210. };
  211. template <typename Derived>
  212. struct max_coeff_visitor<Derived, PropagateNumbers> : coeff_visitor<Derived>
  213. {
  214. typedef typename Derived::Scalar Scalar;
  215. EIGEN_DEVICE_FUNC
  216. void operator() (const Scalar& value, Index i, Index j)
  217. {
  218. if((numext::isnan)(this->res) || (!(numext::isnan)(value) && value > this->res))
  219. {
  220. this->res = value;
  221. this->row = i;
  222. this->col = j;
  223. }
  224. }
  225. };
  226. template <typename Derived>
  227. struct max_coeff_visitor<Derived, PropagateNaN> : coeff_visitor<Derived>
  228. {
  229. typedef typename Derived::Scalar Scalar;
  230. EIGEN_DEVICE_FUNC
  231. void operator() (const Scalar& value, Index i, Index j)
  232. {
  233. if((numext::isnan)(value) || value > this->res)
  234. {
  235. this->res = value;
  236. this->row = i;
  237. this->col = j;
  238. }
  239. }
  240. };
  241. template<typename Scalar, int NaNPropagation>
  242. struct functor_traits<max_coeff_visitor<Scalar, NaNPropagation> > {
  243. enum {
  244. Cost = NumTraits<Scalar>::AddCost
  245. };
  246. };
  247. } // end namespace internal
  248. /** \fn DenseBase<Derived>::minCoeff(IndexType* rowId, IndexType* colId) const
  249. * \returns the minimum of all coefficients of *this and puts in *row and *col its location.
  250. *
  251. * In case \c *this contains NaN, NaNPropagation determines the behavior:
  252. * NaNPropagation == PropagateFast : undefined
  253. * NaNPropagation == PropagateNaN : result is NaN
  254. * NaNPropagation == PropagateNumbers : result is maximum of elements that are not NaN
  255. * \warning the matrix must be not empty, otherwise an assertion is triggered.
  256. *
  257. * \sa DenseBase::minCoeff(Index*), DenseBase::maxCoeff(Index*,Index*), DenseBase::visit(), DenseBase::minCoeff()
  258. */
  259. template<typename Derived>
  260. template<int NaNPropagation, typename IndexType>
  261. EIGEN_DEVICE_FUNC
  262. typename internal::traits<Derived>::Scalar
  263. DenseBase<Derived>::minCoeff(IndexType* rowId, IndexType* colId) const
  264. {
  265. eigen_assert(this->rows()>0 && this->cols()>0 && "you are using an empty matrix");
  266. internal::min_coeff_visitor<Derived, NaNPropagation> minVisitor;
  267. this->visit(minVisitor);
  268. *rowId = minVisitor.row;
  269. if (colId) *colId = minVisitor.col;
  270. return minVisitor.res;
  271. }
  272. /** \returns the minimum of all coefficients of *this and puts in *index its location.
  273. *
  274. * In case \c *this contains NaN, NaNPropagation determines the behavior:
  275. * NaNPropagation == PropagateFast : undefined
  276. * NaNPropagation == PropagateNaN : result is NaN
  277. * NaNPropagation == PropagateNumbers : result is maximum of elements that are not NaN
  278. * \warning the matrix must be not empty, otherwise an assertion is triggered.
  279. *
  280. * \sa DenseBase::minCoeff(IndexType*,IndexType*), DenseBase::maxCoeff(IndexType*,IndexType*), DenseBase::visit(), DenseBase::minCoeff()
  281. */
  282. template<typename Derived>
  283. template<int NaNPropagation, typename IndexType>
  284. EIGEN_DEVICE_FUNC
  285. typename internal::traits<Derived>::Scalar
  286. DenseBase<Derived>::minCoeff(IndexType* index) const
  287. {
  288. eigen_assert(this->rows()>0 && this->cols()>0 && "you are using an empty matrix");
  289. EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
  290. internal::min_coeff_visitor<Derived, NaNPropagation> minVisitor;
  291. this->visit(minVisitor);
  292. *index = IndexType((RowsAtCompileTime==1) ? minVisitor.col : minVisitor.row);
  293. return minVisitor.res;
  294. }
  295. /** \fn DenseBase<Derived>::maxCoeff(IndexType* rowId, IndexType* colId) const
  296. * \returns the maximum of all coefficients of *this and puts in *row and *col its location.
  297. *
  298. * In case \c *this contains NaN, NaNPropagation determines the behavior:
  299. * NaNPropagation == PropagateFast : undefined
  300. * NaNPropagation == PropagateNaN : result is NaN
  301. * NaNPropagation == PropagateNumbers : result is maximum of elements that are not NaN
  302. * \warning the matrix must be not empty, otherwise an assertion is triggered.
  303. *
  304. * \sa DenseBase::minCoeff(IndexType*,IndexType*), DenseBase::visit(), DenseBase::maxCoeff()
  305. */
  306. template<typename Derived>
  307. template<int NaNPropagation, typename IndexType>
  308. EIGEN_DEVICE_FUNC
  309. typename internal::traits<Derived>::Scalar
  310. DenseBase<Derived>::maxCoeff(IndexType* rowPtr, IndexType* colPtr) const
  311. {
  312. eigen_assert(this->rows()>0 && this->cols()>0 && "you are using an empty matrix");
  313. internal::max_coeff_visitor<Derived, NaNPropagation> maxVisitor;
  314. this->visit(maxVisitor);
  315. *rowPtr = maxVisitor.row;
  316. if (colPtr) *colPtr = maxVisitor.col;
  317. return maxVisitor.res;
  318. }
  319. /** \returns the maximum of all coefficients of *this and puts in *index its location.
  320. *
  321. * In case \c *this contains NaN, NaNPropagation determines the behavior:
  322. * NaNPropagation == PropagateFast : undefined
  323. * NaNPropagation == PropagateNaN : result is NaN
  324. * NaNPropagation == PropagateNumbers : result is maximum of elements that are not NaN
  325. * \warning the matrix must be not empty, otherwise an assertion is triggered.
  326. *
  327. * \sa DenseBase::maxCoeff(IndexType*,IndexType*), DenseBase::minCoeff(IndexType*,IndexType*), DenseBase::visitor(), DenseBase::maxCoeff()
  328. */
  329. template<typename Derived>
  330. template<int NaNPropagation, typename IndexType>
  331. EIGEN_DEVICE_FUNC
  332. typename internal::traits<Derived>::Scalar
  333. DenseBase<Derived>::maxCoeff(IndexType* index) const
  334. {
  335. eigen_assert(this->rows()>0 && this->cols()>0 && "you are using an empty matrix");
  336. EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
  337. internal::max_coeff_visitor<Derived, NaNPropagation> maxVisitor;
  338. this->visit(maxVisitor);
  339. *index = (RowsAtCompileTime==1) ? maxVisitor.col : maxVisitor.row;
  340. return maxVisitor.res;
  341. }
  342. } // end namespace Eigen
  343. #endif // EIGEN_VISITOR_H