IndexedView.h 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. // This file is part of Eigen, a lightweight C++ template library
  2. // for linear algebra.
  3. //
  4. // Copyright (C) 2017 Gael Guennebaud <gael.guennebaud@inria.fr>
  5. //
  6. // This Source Code Form is subject to the terms of the Mozilla
  7. // Public License v. 2.0. If a copy of the MPL was not distributed
  8. // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
  9. #ifndef EIGEN_INDEXED_VIEW_H
  10. #define EIGEN_INDEXED_VIEW_H
  11. namespace Eigen {
  12. namespace internal {
  13. template<typename XprType, typename RowIndices, typename ColIndices>
  14. struct traits<IndexedView<XprType, RowIndices, ColIndices> >
  15. : traits<XprType>
  16. {
  17. enum {
  18. RowsAtCompileTime = int(array_size<RowIndices>::value),
  19. ColsAtCompileTime = int(array_size<ColIndices>::value),
  20. MaxRowsAtCompileTime = RowsAtCompileTime != Dynamic ? int(RowsAtCompileTime) : Dynamic,
  21. MaxColsAtCompileTime = ColsAtCompileTime != Dynamic ? int(ColsAtCompileTime) : Dynamic,
  22. XprTypeIsRowMajor = (int(traits<XprType>::Flags)&RowMajorBit) != 0,
  23. IsRowMajor = (MaxRowsAtCompileTime==1&&MaxColsAtCompileTime!=1) ? 1
  24. : (MaxColsAtCompileTime==1&&MaxRowsAtCompileTime!=1) ? 0
  25. : XprTypeIsRowMajor,
  26. RowIncr = int(get_compile_time_incr<RowIndices>::value),
  27. ColIncr = int(get_compile_time_incr<ColIndices>::value),
  28. InnerIncr = IsRowMajor ? ColIncr : RowIncr,
  29. OuterIncr = IsRowMajor ? RowIncr : ColIncr,
  30. HasSameStorageOrderAsXprType = (IsRowMajor == XprTypeIsRowMajor),
  31. XprInnerStride = HasSameStorageOrderAsXprType ? int(inner_stride_at_compile_time<XprType>::ret) : int(outer_stride_at_compile_time<XprType>::ret),
  32. XprOuterstride = HasSameStorageOrderAsXprType ? int(outer_stride_at_compile_time<XprType>::ret) : int(inner_stride_at_compile_time<XprType>::ret),
  33. InnerSize = XprTypeIsRowMajor ? ColsAtCompileTime : RowsAtCompileTime,
  34. IsBlockAlike = InnerIncr==1 && OuterIncr==1,
  35. IsInnerPannel = HasSameStorageOrderAsXprType && is_same<AllRange<InnerSize>,typename conditional<XprTypeIsRowMajor,ColIndices,RowIndices>::type>::value,
  36. InnerStrideAtCompileTime = InnerIncr<0 || InnerIncr==DynamicIndex || XprInnerStride==Dynamic ? Dynamic : XprInnerStride * InnerIncr,
  37. OuterStrideAtCompileTime = OuterIncr<0 || OuterIncr==DynamicIndex || XprOuterstride==Dynamic ? Dynamic : XprOuterstride * OuterIncr,
  38. ReturnAsScalar = is_same<RowIndices,SingleRange>::value && is_same<ColIndices,SingleRange>::value,
  39. ReturnAsBlock = (!ReturnAsScalar) && IsBlockAlike,
  40. ReturnAsIndexedView = (!ReturnAsScalar) && (!ReturnAsBlock),
  41. // FIXME we deal with compile-time strides if and only if we have DirectAccessBit flag,
  42. // but this is too strict regarding negative strides...
  43. DirectAccessMask = (int(InnerIncr)!=UndefinedIncr && int(OuterIncr)!=UndefinedIncr && InnerIncr>=0 && OuterIncr>=0) ? DirectAccessBit : 0,
  44. FlagsRowMajorBit = IsRowMajor ? RowMajorBit : 0,
  45. FlagsLvalueBit = is_lvalue<XprType>::value ? LvalueBit : 0,
  46. FlagsLinearAccessBit = (RowsAtCompileTime == 1 || ColsAtCompileTime == 1) ? LinearAccessBit : 0,
  47. Flags = (traits<XprType>::Flags & (HereditaryBits | DirectAccessMask )) | FlagsLvalueBit | FlagsRowMajorBit | FlagsLinearAccessBit
  48. };
  49. typedef Block<XprType,RowsAtCompileTime,ColsAtCompileTime,IsInnerPannel> BlockType;
  50. };
  51. }
  52. template<typename XprType, typename RowIndices, typename ColIndices, typename StorageKind>
  53. class IndexedViewImpl;
/** \class IndexedView
  * \ingroup Core_Module
  *
  * \brief Expression of a non-sequential sub-matrix defined by arbitrary sequences of row and column indices
  *
  * \tparam XprType the type of the expression in which we are taking the intersections of sub-rows and sub-columns
  * \tparam RowIndices the type of the object defining the sequence of row indices
  * \tparam ColIndices the type of the object defining the sequence of column indices
  *
  * This class represents an expression of a sub-matrix (or sub-vector) defined as the intersection
  * of subsets of rows and columns, which are themselves defined by generic sequences of row indices \f$ \{r_0,r_1,\ldots,r_{m-1}\} \f$
  * and column indices \f$ \{c_0,c_1,\ldots,c_{n-1}\} \f$. Let \f$ A \f$ be the nested matrix; then the resulting matrix \f$ B \f$ has \c m
  * rows and \c n columns, and its entries are given by \f$ B(i,j) = A(r_i,c_j) \f$.
  *
  * The \c RowIndices and \c ColIndices types must be compatible with the following API:
  * \code
  * <integral type> operator[](Index) const;
  * Index size() const;
  * \endcode
  *
  * Typical supported types thus include:
  *  - std::vector<int>
  *  - std::valarray<int>
  *  - std::array<int,N>
  *  - Plain C arrays: int[N]
  *  - Eigen::ArrayXi
  *  - decltype(ArrayXi::LinSpaced(...))
  *  - Any views or expressions of the previous types
  *  - Eigen::ArithmeticSequence
  *  - Eigen::internal::AllRange (helper for Eigen::all)
  *  - Eigen::internal::SingleRange (helper for single index)
  *  - etc.
  *
  * In typical usage of %Eigen, this class is never used directly: it is the return type of
  * DenseBase::operator()(const RowIndices&, const ColIndices&).
  *
  * \sa class Block
  */
  92. template<typename XprType, typename RowIndices, typename ColIndices>
  93. class IndexedView : public IndexedViewImpl<XprType, RowIndices, ColIndices, typename internal::traits<XprType>::StorageKind>
  94. {
  95. public:
  96. typedef typename IndexedViewImpl<XprType, RowIndices, ColIndices, typename internal::traits<XprType>::StorageKind>::Base Base;
  97. EIGEN_GENERIC_PUBLIC_INTERFACE(IndexedView)
  98. EIGEN_INHERIT_ASSIGNMENT_OPERATORS(IndexedView)
  99. typedef typename internal::ref_selector<XprType>::non_const_type MatrixTypeNested;
  100. typedef typename internal::remove_all<XprType>::type NestedExpression;
  101. template<typename T0, typename T1>
  102. IndexedView(XprType& xpr, const T0& rowIndices, const T1& colIndices)
  103. : m_xpr(xpr), m_rowIndices(rowIndices), m_colIndices(colIndices)
  104. {}
  105. /** \returns number of rows */
  106. Index rows() const { return internal::size(m_rowIndices); }
  107. /** \returns number of columns */
  108. Index cols() const { return internal::size(m_colIndices); }
  109. /** \returns the nested expression */
  110. const typename internal::remove_all<XprType>::type&
  111. nestedExpression() const { return m_xpr; }
  112. /** \returns the nested expression */
  113. typename internal::remove_reference<XprType>::type&
  114. nestedExpression() { return m_xpr; }
  115. /** \returns a const reference to the object storing/generating the row indices */
  116. const RowIndices& rowIndices() const { return m_rowIndices; }
  117. /** \returns a const reference to the object storing/generating the column indices */
  118. const ColIndices& colIndices() const { return m_colIndices; }
  119. protected:
  120. MatrixTypeNested m_xpr;
  121. RowIndices m_rowIndices;
  122. ColIndices m_colIndices;
  123. };
  124. // Generic API dispatcher
  125. template<typename XprType, typename RowIndices, typename ColIndices, typename StorageKind>
  126. class IndexedViewImpl
  127. : public internal::generic_xpr_base<IndexedView<XprType, RowIndices, ColIndices> >::type
  128. {
  129. public:
  130. typedef typename internal::generic_xpr_base<IndexedView<XprType, RowIndices, ColIndices> >::type Base;
  131. };
  132. namespace internal {
  133. template<typename ArgType, typename RowIndices, typename ColIndices>
  134. struct unary_evaluator<IndexedView<ArgType, RowIndices, ColIndices>, IndexBased>
  135. : evaluator_base<IndexedView<ArgType, RowIndices, ColIndices> >
  136. {
  137. typedef IndexedView<ArgType, RowIndices, ColIndices> XprType;
  138. enum {
  139. CoeffReadCost = evaluator<ArgType>::CoeffReadCost /* TODO + cost of row/col index */,
  140. FlagsLinearAccessBit = (traits<XprType>::RowsAtCompileTime == 1 || traits<XprType>::ColsAtCompileTime == 1) ? LinearAccessBit : 0,
  141. FlagsRowMajorBit = traits<XprType>::FlagsRowMajorBit,
  142. Flags = (evaluator<ArgType>::Flags & (HereditaryBits & ~RowMajorBit /*| LinearAccessBit | DirectAccessBit*/)) | FlagsLinearAccessBit | FlagsRowMajorBit,
  143. Alignment = 0
  144. };
  145. EIGEN_DEVICE_FUNC explicit unary_evaluator(const XprType& xpr) : m_argImpl(xpr.nestedExpression()), m_xpr(xpr)
  146. {
  147. EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
  148. }
  149. typedef typename XprType::Scalar Scalar;
  150. typedef typename XprType::CoeffReturnType CoeffReturnType;
  151. EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  152. CoeffReturnType coeff(Index row, Index col) const
  153. {
  154. return m_argImpl.coeff(m_xpr.rowIndices()[row], m_xpr.colIndices()[col]);
  155. }
  156. EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  157. Scalar& coeffRef(Index row, Index col)
  158. {
  159. return m_argImpl.coeffRef(m_xpr.rowIndices()[row], m_xpr.colIndices()[col]);
  160. }
  161. EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  162. Scalar& coeffRef(Index index)
  163. {
  164. EIGEN_STATIC_ASSERT_LVALUE(XprType)
  165. Index row = XprType::RowsAtCompileTime == 1 ? 0 : index;
  166. Index col = XprType::RowsAtCompileTime == 1 ? index : 0;
  167. return m_argImpl.coeffRef( m_xpr.rowIndices()[row], m_xpr.colIndices()[col]);
  168. }
  169. EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  170. const Scalar& coeffRef(Index index) const
  171. {
  172. Index row = XprType::RowsAtCompileTime == 1 ? 0 : index;
  173. Index col = XprType::RowsAtCompileTime == 1 ? index : 0;
  174. return m_argImpl.coeffRef( m_xpr.rowIndices()[row], m_xpr.colIndices()[col]);
  175. }
  176. EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  177. const CoeffReturnType coeff(Index index) const
  178. {
  179. Index row = XprType::RowsAtCompileTime == 1 ? 0 : index;
  180. Index col = XprType::RowsAtCompileTime == 1 ? index : 0;
  181. return m_argImpl.coeff( m_xpr.rowIndices()[row], m_xpr.colIndices()[col]);
  182. }
  183. protected:
  184. evaluator<ArgType> m_argImpl;
  185. const XprType& m_xpr;
  186. };
  187. } // end namespace internal
  188. } // end namespace Eigen
  189. #endif // EIGEN_INDEXED_VIEW_H