SolverBase.h 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. // This file is part of Eigen, a lightweight C++ template library
  2. // for linear algebra.
  3. //
  4. // Copyright (C) 2015 Gael Guennebaud <gael.guennebaud@inria.fr>
  5. //
  6. // This Source Code Form is subject to the terms of the Mozilla
  7. // Public License v. 2.0. If a copy of the MPL was not distributed
  8. // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
  9. #ifndef EIGEN_SOLVERBASE_H
  10. #define EIGEN_SOLVERBASE_H
  11. namespace Eigen {
  12. namespace internal {
  13. template<typename Derived>
  14. struct solve_assertion {
  15. template<bool Transpose_, typename Rhs>
  16. static void run(const Derived& solver, const Rhs& b) { solver.template _check_solve_assertion<Transpose_>(b); }
  17. };
  18. template<typename Derived>
  19. struct solve_assertion<Transpose<Derived> >
  20. {
  21. typedef Transpose<Derived> type;
  22. template<bool Transpose_, typename Rhs>
  23. static void run(const type& transpose, const Rhs& b)
  24. {
  25. internal::solve_assertion<typename internal::remove_all<Derived>::type>::template run<true>(transpose.nestedExpression(), b);
  26. }
  27. };
  28. template<typename Scalar, typename Derived>
  29. struct solve_assertion<CwiseUnaryOp<Eigen::internal::scalar_conjugate_op<Scalar>, const Transpose<Derived> > >
  30. {
  31. typedef CwiseUnaryOp<Eigen::internal::scalar_conjugate_op<Scalar>, const Transpose<Derived> > type;
  32. template<bool Transpose_, typename Rhs>
  33. static void run(const type& adjoint, const Rhs& b)
  34. {
  35. internal::solve_assertion<typename internal::remove_all<Transpose<Derived> >::type>::template run<true>(adjoint.nestedExpression(), b);
  36. }
  37. };
  38. } // end namespace internal
  39. /** \class SolverBase
  40. * \brief A base class for matrix decomposition and solvers
  41. *
  42. * \tparam Derived the actual type of the decomposition/solver.
  43. *
  44. * Any matrix decomposition inheriting this base class provide the following API:
  45. *
  46. * \code
  47. * MatrixType A, b, x;
  48. * DecompositionType dec(A);
  49. * x = dec.solve(b); // solve A * x = b
  50. * x = dec.transpose().solve(b); // solve A^T * x = b
  51. * x = dec.adjoint().solve(b); // solve A' * x = b
  52. * \endcode
  53. *
  54. * \warning Currently, any other usage of transpose() and adjoint() are not supported and will produce compilation errors.
  55. *
  56. * \sa class PartialPivLU, class FullPivLU, class HouseholderQR, class ColPivHouseholderQR, class FullPivHouseholderQR, class CompleteOrthogonalDecomposition, class LLT, class LDLT, class SVDBase
  57. */
  58. template<typename Derived>
  59. class SolverBase : public EigenBase<Derived>
  60. {
  61. public:
  62. typedef EigenBase<Derived> Base;
  63. typedef typename internal::traits<Derived>::Scalar Scalar;
  64. typedef Scalar CoeffReturnType;
  65. template<typename Derived_>
  66. friend struct internal::solve_assertion;
  67. enum {
  68. RowsAtCompileTime = internal::traits<Derived>::RowsAtCompileTime,
  69. ColsAtCompileTime = internal::traits<Derived>::ColsAtCompileTime,
  70. SizeAtCompileTime = (internal::size_at_compile_time<internal::traits<Derived>::RowsAtCompileTime,
  71. internal::traits<Derived>::ColsAtCompileTime>::ret),
  72. MaxRowsAtCompileTime = internal::traits<Derived>::MaxRowsAtCompileTime,
  73. MaxColsAtCompileTime = internal::traits<Derived>::MaxColsAtCompileTime,
  74. MaxSizeAtCompileTime = (internal::size_at_compile_time<internal::traits<Derived>::MaxRowsAtCompileTime,
  75. internal::traits<Derived>::MaxColsAtCompileTime>::ret),
  76. IsVectorAtCompileTime = internal::traits<Derived>::MaxRowsAtCompileTime == 1
  77. || internal::traits<Derived>::MaxColsAtCompileTime == 1,
  78. NumDimensions = int(MaxSizeAtCompileTime) == 1 ? 0 : bool(IsVectorAtCompileTime) ? 1 : 2
  79. };
  80. /** Default constructor */
  81. SolverBase()
  82. {}
  83. ~SolverBase()
  84. {}
  85. using Base::derived;
  86. /** \returns an expression of the solution x of \f$ A x = b \f$ using the current decomposition of A.
  87. */
  88. template<typename Rhs>
  89. inline const Solve<Derived, Rhs>
  90. solve(const MatrixBase<Rhs>& b) const
  91. {
  92. internal::solve_assertion<typename internal::remove_all<Derived>::type>::template run<false>(derived(), b);
  93. return Solve<Derived, Rhs>(derived(), b.derived());
  94. }
  95. /** \internal the return type of transpose() */
  96. typedef typename internal::add_const<Transpose<const Derived> >::type ConstTransposeReturnType;
  97. /** \returns an expression of the transposed of the factored matrix.
  98. *
  99. * A typical usage is to solve for the transposed problem A^T x = b:
  100. * \code x = dec.transpose().solve(b); \endcode
  101. *
  102. * \sa adjoint(), solve()
  103. */
  104. inline ConstTransposeReturnType transpose() const
  105. {
  106. return ConstTransposeReturnType(derived());
  107. }
  108. /** \internal the return type of adjoint() */
  109. typedef typename internal::conditional<NumTraits<Scalar>::IsComplex,
  110. CwiseUnaryOp<internal::scalar_conjugate_op<Scalar>, ConstTransposeReturnType>,
  111. ConstTransposeReturnType
  112. >::type AdjointReturnType;
  113. /** \returns an expression of the adjoint of the factored matrix
  114. *
  115. * A typical usage is to solve for the adjoint problem A' x = b:
  116. * \code x = dec.adjoint().solve(b); \endcode
  117. *
  118. * For real scalar types, this function is equivalent to transpose().
  119. *
  120. * \sa transpose(), solve()
  121. */
  122. inline AdjointReturnType adjoint() const
  123. {
  124. return AdjointReturnType(derived().transpose());
  125. }
  126. protected:
  127. template<bool Transpose_, typename Rhs>
  128. void _check_solve_assertion(const Rhs& b) const {
  129. EIGEN_ONLY_USED_FOR_DEBUG(b);
  130. eigen_assert(derived().m_isInitialized && "Solver is not initialized.");
  131. eigen_assert((Transpose_?derived().cols():derived().rows())==b.rows() && "SolverBase::solve(): invalid number of rows of the right hand side matrix b");
  132. }
  133. };
  134. namespace internal {
  135. template<typename Derived>
  136. struct generic_xpr_base<Derived, MatrixXpr, SolverStorage>
  137. {
  138. typedef SolverBase<Derived> type;
  139. };
  140. } // end namespace internal
  141. } // end namespace Eigen
  142. #endif // EIGEN_SOLVERBASE_H