fn_spsolve.hpp 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup fn_spsolve
  16. //! @{
  17. //! Solve a system of linear equations, i.e., A*X = B, where X is unknown,
  18. //! A is sparse, and B is dense. X will be dense too.
  19. template<typename T1, typename T2>
  20. inline
  21. bool
  22. spsolve_helper
  23. (
  24. Mat<typename T1::elem_type>& out,
  25. const SpBase<typename T1::elem_type, T1>& A,
  26. const Base<typename T1::elem_type, T2>& B,
  27. const char* solver,
  28. const spsolve_opts_base& settings,
  29. const typename arma_blas_type_only<typename T1::elem_type>::result* junk = 0
  30. )
  31. {
  32. arma_extra_debug_sigprint();
  33. arma_ignore(junk);
  34. typedef typename T1::pod_type T;
  35. typedef typename T1::elem_type eT;
  36. const char sig = (solver != NULL) ? solver[0] : char(0);
  37. arma_debug_check( ((sig != 'l') && (sig != 's')), "spsolve(): unknown solver" );
  38. T rcond = T(0);
  39. bool status = false;
  40. superlu_opts superlu_opts_default;
  41. // if(is_float <T>::value) { superlu_opts_default.refine = superlu_opts::REF_SINGLE; }
  42. // if(is_double<T>::value) { superlu_opts_default.refine = superlu_opts::REF_DOUBLE; }
  43. const superlu_opts& opts = (settings.id == 1) ? static_cast<const superlu_opts&>(settings) : superlu_opts_default;
  44. arma_debug_check( ( (opts.pivot_thresh < double(0)) || (opts.pivot_thresh > double(1)) ), "spsolve(): pivot_thresh out of bounds" );
  45. if(sig == 's') // SuperLU solver
  46. {
  47. if( (opts.equilibrate == false) && (opts.refine == superlu_opts::REF_NONE) )
  48. {
  49. status = sp_auxlib::spsolve_simple(out, A.get_ref(), B.get_ref(), opts);
  50. }
  51. else
  52. {
  53. status = sp_auxlib::spsolve_refine(out, rcond, A.get_ref(), B.get_ref(), opts);
  54. }
  55. }
  56. else
  57. if(sig == 'l') // brutal LAPACK solver
  58. {
  59. if( (settings.id != 0) && ((opts.symmetric) || (opts.pivot_thresh != double(1))) )
  60. {
  61. arma_debug_warn("spsolve(): ignoring settings not applicable to LAPACK based solver");
  62. }
  63. Mat<eT> AA;
  64. bool conversion_ok = false;
  65. try
  66. {
  67. Mat<eT> tmp(A.get_ref()); // conversion from sparse to dense can throw std::bad_alloc
  68. AA.steal_mem(tmp);
  69. conversion_ok = true;
  70. }
  71. catch(std::bad_alloc&)
  72. {
  73. arma_debug_warn("spsolve(): not enough memory to use LAPACK based solver");
  74. }
  75. if(conversion_ok)
  76. {
  77. arma_debug_check( (AA.n_rows != AA.n_cols), "spsolve(): matrix A must be square sized" );
  78. uword flags = solve_opts::flag_none;
  79. if(opts.refine != superlu_opts::REF_NONE) { flags |= solve_opts::flag_refine; }
  80. if(opts.equilibrate == true ) { flags |= solve_opts::flag_equilibrate; }
  81. if(opts.allow_ugly == true ) { flags |= solve_opts::flag_allow_ugly; }
  82. status = glue_solve_gen::apply(out, AA, B.get_ref(), flags);
  83. }
  84. }
  85. if(status == false)
  86. {
  87. if(rcond > T(0)) { arma_debug_warn("spsolve(): system seems singular (rcond: ", rcond, ")"); }
  88. else { arma_debug_warn("spsolve(): system seems singular"); }
  89. out.soft_reset();
  90. }
  91. if( (status == true) && (rcond > T(0)) && (rcond < auxlib::epsilon_lapack(out)) )
  92. {
  93. arma_debug_warn("solve(): solution computed, but system seems singular to working precision (rcond: ", rcond, ")");
  94. }
  95. return status;
  96. }
  97. template<typename T1, typename T2>
  98. inline
  99. bool
  100. spsolve
  101. (
  102. Mat<typename T1::elem_type>& out,
  103. const SpBase<typename T1::elem_type, T1>& A,
  104. const Base<typename T1::elem_type, T2>& B,
  105. const char* solver = "superlu",
  106. const spsolve_opts_base& settings = spsolve_opts_none(),
  107. const typename arma_blas_type_only<typename T1::elem_type>::result* junk = 0
  108. )
  109. {
  110. arma_extra_debug_sigprint();
  111. arma_ignore(junk);
  112. const bool status = spsolve_helper(out, A.get_ref(), B.get_ref(), solver, settings);
  113. return status;
  114. }
  115. template<typename T1, typename T2>
  116. arma_warn_unused
  117. inline
  118. Mat<typename T1::elem_type>
  119. spsolve
  120. (
  121. const SpBase<typename T1::elem_type, T1>& A,
  122. const Base<typename T1::elem_type, T2>& B,
  123. const char* solver = "superlu",
  124. const spsolve_opts_base& settings = spsolve_opts_none(),
  125. const typename arma_blas_type_only<typename T1::elem_type>::result* junk = 0
  126. )
  127. {
  128. arma_extra_debug_sigprint();
  129. arma_ignore(junk);
  130. typedef typename T1::elem_type eT;
  131. Mat<eT> out;
  132. const bool status = spsolve_helper(out, A.get_ref(), B.get_ref(), solver, settings);
  133. if(status == false)
  134. {
  135. arma_stop_runtime_error("spsolve(): solution not found");
  136. }
  137. return out;
  138. }
  139. //! @}