// glue_solve_meat.hpp

  // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  // Copyright 2008-2016 National ICT Australia (NICTA)
  //
  // Licensed under the Apache License, Version 2.0 (the "License");
  // you may not use this file except in compliance with the License.
  // You may obtain a copy of the License at
  // http://www.apache.org/licenses/LICENSE-2.0
  //
  // Unless required by applicable law or agreed to in writing, software
  // distributed under the License is distributed on an "AS IS" BASIS,
  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  // See the License for the specific language governing permissions and
  // limitations under the License.
  // ------------------------------------------------------------------------
  //! \addtogroup glue_solve
  //! @{
  
  //
  // glue_solve_gen
  19. template<typename T1, typename T2>
  20. inline
  21. void
  22. glue_solve_gen::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_solve_gen>& X)
  23. {
  24. arma_extra_debug_sigprint();
  25. const bool status = glue_solve_gen::apply( out, X.A, X.B, X.aux_uword );
  26. if(status == false)
  27. {
  28. arma_stop_runtime_error("solve(): solution not found");
  29. }
  30. }
  31. template<typename eT, typename T1, typename T2>
  32. inline
  33. bool
  34. glue_solve_gen::apply(Mat<eT>& out, const Base<eT,T1>& A_expr, const Base<eT,T2>& B_expr, const uword flags)
  35. {
  36. arma_extra_debug_sigprint();
  37. typedef typename get_pod_type<eT>::result T;
  38. const bool fast = bool(flags & solve_opts::flag_fast );
  39. const bool equilibrate = bool(flags & solve_opts::flag_equilibrate );
  40. const bool no_approx = bool(flags & solve_opts::flag_no_approx );
  41. const bool no_band = bool(flags & solve_opts::flag_no_band );
  42. const bool no_sympd = bool(flags & solve_opts::flag_no_sympd );
  43. const bool allow_ugly = bool(flags & solve_opts::flag_allow_ugly );
  44. const bool likely_sympd = bool(flags & solve_opts::flag_likely_sympd);
  45. const bool refine = bool(flags & solve_opts::flag_refine );
  46. const bool no_trimat = bool(flags & solve_opts::flag_no_trimat );
  47. arma_extra_debug_print("glue_solve_gen::apply(): enabled flags:");
  48. if(fast ) { arma_extra_debug_print("fast"); }
  49. if(equilibrate ) { arma_extra_debug_print("equilibrate"); }
  50. if(no_approx ) { arma_extra_debug_print("no_approx"); }
  51. if(no_band ) { arma_extra_debug_print("no_band"); }
  52. if(no_sympd ) { arma_extra_debug_print("no_sympd"); }
  53. if(allow_ugly ) { arma_extra_debug_print("allow_ugly"); }
  54. if(likely_sympd) { arma_extra_debug_print("likely_sympd"); }
  55. if(refine ) { arma_extra_debug_print("refine"); }
  56. if(no_trimat ) { arma_extra_debug_print("no_trimat"); }
  57. arma_debug_check( (fast && equilibrate ), "solve(): options 'fast' and 'equilibrate' are mutually exclusive" );
  58. arma_debug_check( (fast && refine ), "solve(): options 'fast' and 'refine' are mutually exclusive" );
  59. arma_debug_check( (no_sympd && likely_sympd), "solve(): options 'no_sympd' and 'likely_sympd' are mutually exclusive" );
  60. T rcond = T(0);
  61. bool status = false;
  62. Mat<eT> A = A_expr.get_ref();
  63. if(A.n_rows == A.n_cols)
  64. {
  65. arma_extra_debug_print("glue_solve_gen::apply(): detected square system");
  66. uword KL = 0;
  67. uword KU = 0;
  68. #if defined(ARMA_OPTIMISE_BAND)
  69. const bool is_band = (no_band || auxlib::crippled_lapack(A)) ? false : band_helper::is_band(KL, KU, A, uword(32));
  70. #else
  71. const bool is_band = false;
  72. #endif
  73. const bool is_triu = (no_trimat || refine || equilibrate || likely_sympd || is_band ) ? false : trimat_helper::is_triu(A);
  74. const bool is_tril = (no_trimat || refine || equilibrate || likely_sympd || is_band || is_triu) ? false : trimat_helper::is_tril(A);
  75. #if defined(ARMA_OPTIMISE_SYMPD)
  76. const bool try_sympd = (no_sympd || auxlib::crippled_lapack(A) || is_band || is_triu || is_tril) ? false : (likely_sympd ? true : sympd_helper::guess_sympd(A));
  77. #else
  78. const bool try_sympd = false;
  79. #endif
  80. if(fast)
  81. {
  82. // fast mode: solvers without refinement and without rcond estimate
  83. arma_extra_debug_print("glue_solve_gen::apply(): fast mode");
  84. if(is_band)
  85. {
  86. if( (KL == 1) && (KU == 1) )
  87. {
  88. arma_extra_debug_print("glue_solve_gen::apply(): fast + tridiagonal");
  89. status = auxlib::solve_tridiag_fast(out, A, B_expr.get_ref());
  90. }
  91. else
  92. {
  93. arma_extra_debug_print("glue_solve_gen::apply(): fast + band");
  94. status = auxlib::solve_band_fast(out, A, KL, KU, B_expr.get_ref());
  95. }
  96. }
  97. else
  98. if(is_triu || is_tril)
  99. {
  100. if(is_triu) { arma_extra_debug_print("glue_solve_gen::apply(): fast + upper triangular matrix"); }
  101. if(is_tril) { arma_extra_debug_print("glue_solve_gen::apply(): fast + lower triangular matrix"); }
  102. const uword layout = (is_triu) ? uword(0) : uword(1);
  103. status = auxlib::solve_trimat_fast(out, A, B_expr.get_ref(), layout);
  104. }
  105. else
  106. if(try_sympd)
  107. {
  108. arma_extra_debug_print("glue_solve_gen::apply(): fast + try_sympd");
  109. status = auxlib::solve_sympd_fast(out, A, B_expr.get_ref()); // A is overwritten
  110. if(status == false)
  111. {
  112. arma_extra_debug_print("glue_solve_gen::apply(): auxlib::solve_sympd_fast() failed; retrying");
  113. // auxlib::solve_sympd_fast() may have failed because A isn't really sympd
  114. A = A_expr.get_ref();
  115. status = auxlib::solve_square_fast(out, A, B_expr.get_ref()); // A is overwritten
  116. }
  117. }
  118. else
  119. {
  120. arma_extra_debug_print("glue_solve_gen::apply(): fast + dense");
  121. status = auxlib::solve_square_fast(out, A, B_expr.get_ref()); // A is overwritten
  122. }
  123. }
  124. else
  125. if(refine || equilibrate)
  126. {
  127. // refine mode: solvers with refinement and with rcond estimate
  128. arma_extra_debug_print("glue_solve_gen::apply(): refine mode");
  129. if(is_band)
  130. {
  131. arma_extra_debug_print("glue_solve_gen::apply(): refine + band");
  132. status = auxlib::solve_band_refine(out, rcond, A, KL, KU, B_expr, equilibrate, allow_ugly);
  133. }
  134. else
  135. if(try_sympd)
  136. {
  137. arma_extra_debug_print("glue_solve_gen::apply(): refine + try_sympd");
  138. status = auxlib::solve_sympd_refine(out, rcond, A, B_expr.get_ref(), equilibrate, allow_ugly); // A is overwritten
  139. if(status == false)
  140. {
  141. arma_extra_debug_print("glue_solve_gen::apply(): auxlib::solve_sympd_refine() failed; retrying");
  142. // auxlib::solve_sympd_refine() may have failed because A isn't really sympd
  143. A = A_expr.get_ref();
  144. status = auxlib::solve_square_refine(out, rcond, A, B_expr.get_ref(), equilibrate, allow_ugly); // A is overwritten
  145. }
  146. }
  147. else
  148. {
  149. arma_extra_debug_print("glue_solve_gen::apply(): refine + dense");
  150. status = auxlib::solve_square_refine(out, rcond, A, B_expr, equilibrate, allow_ugly); // A is overwritten
  151. }
  152. }
  153. else
  154. {
  155. // default mode: solvers without refinement but with rcond estimate
  156. arma_extra_debug_print("glue_solve_gen::apply(): default mode");
  157. if(is_band)
  158. {
  159. arma_extra_debug_print("glue_solve_gen::apply(): rcond + band");
  160. status = auxlib::solve_band_rcond(out, rcond, A, KL, KU, B_expr.get_ref(), allow_ugly);
  161. }
  162. else
  163. if(is_triu || is_tril)
  164. {
  165. if(is_triu) { arma_extra_debug_print("glue_solve_gen::apply(): rcond + upper triangular matrix"); }
  166. if(is_tril) { arma_extra_debug_print("glue_solve_gen::apply(): rcond + lower triangular matrix"); }
  167. const uword layout = (is_triu) ? uword(0) : uword(1);
  168. status = auxlib::solve_trimat_rcond(out, rcond, A, B_expr.get_ref(), layout, allow_ugly);
  169. }
  170. else
  171. if(try_sympd)
  172. {
  173. status = auxlib::solve_sympd_rcond(out, rcond, A, B_expr.get_ref(), allow_ugly); // A is overwritten
  174. if(status == false)
  175. {
  176. arma_extra_debug_print("glue_solve_gen::apply(): auxlib::solve_sympd_rcond() failed; retrying");
  177. // auxlib::solve_sympd_rcond() may have failed because A isn't really sympd
  178. A = A_expr.get_ref();
  179. status = auxlib::solve_square_rcond(out, rcond, A, B_expr.get_ref(), allow_ugly); // A is overwritten
  180. }
  181. }
  182. else
  183. {
  184. status = auxlib::solve_square_rcond(out, rcond, A, B_expr.get_ref(), allow_ugly); // A is overwritten
  185. }
  186. }
  187. if( (status == true) && (rcond > T(0)) && (rcond < auxlib::epsilon_lapack(A)) )
  188. {
  189. arma_debug_warn("solve(): solution computed, but system seems singular to working precision (rcond: ", rcond, ")");
  190. }
  191. if( (status == false) && (no_approx == false) )
  192. {
  193. arma_extra_debug_print("glue_solve_gen::apply(): solving rank deficient system");
  194. if(rcond > T(0))
  195. {
  196. arma_debug_warn("solve(): system seems singular (rcond: ", rcond, "); attempting approx solution");
  197. }
  198. else
  199. {
  200. arma_debug_warn("solve(): system seems singular; attempting approx solution");
  201. }
  202. // TODO: conditionally recreate A: have a separate state flag which indicates whether A was previously overwritten
  203. A = A_expr.get_ref(); // as A may have been overwritten
  204. status = auxlib::solve_approx_svd(out, A, B_expr.get_ref()); // A is overwritten
  205. }
  206. }
  207. else
  208. {
  209. arma_extra_debug_print("glue_solve_gen::apply(): detected non-square system");
  210. if(equilibrate) { arma_debug_warn( "solve(): option 'equilibrate' ignored for non-square matrix" ); }
  211. if(refine) { arma_debug_warn( "solve(): option 'refine' ignored for non-square matrix" ); }
  212. if(likely_sympd) { arma_debug_warn( "solve(): option 'likely_sympd' ignored for non-square matrix" ); }
  213. if(fast)
  214. {
  215. status = auxlib::solve_rect_fast(out, A, B_expr.get_ref()); // A is overwritten
  216. }
  217. else
  218. {
  219. status = auxlib::solve_rect_rcond(out, rcond, A, B_expr.get_ref(), allow_ugly); // A is overwritten
  220. }
  221. if( (status == true) && (rcond > T(0)) && (rcond < auxlib::epsilon_lapack(A)) )
  222. {
  223. arma_debug_warn("solve(): solution computed, but system seems singular to working precision (rcond: ", rcond, ")");
  224. }
  225. if( (status == false) && (no_approx == false) )
  226. {
  227. arma_extra_debug_print("glue_solve_gen::apply(): solving rank deficient system");
  228. if(rcond > T(0))
  229. {
  230. arma_debug_warn("solve(): system seems singular (rcond: ", rcond, "); attempting approx solution");
  231. }
  232. else
  233. {
  234. arma_debug_warn("solve(): system seems singular; attempting approx solution");
  235. }
  236. A = A_expr.get_ref(); // as A was overwritten
  237. status = auxlib::solve_approx_svd(out, A, B_expr.get_ref()); // A is overwritten
  238. }
  239. }
  240. if(status == false) { out.soft_reset(); }
  241. return status;
  242. }
  //
  // glue_solve_tri
  245. template<typename T1, typename T2>
  246. inline
  247. void
  248. glue_solve_tri_default::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_solve_tri_default>& X)
  249. {
  250. arma_extra_debug_sigprint();
  251. const bool status = glue_solve_tri_default::apply( out, X.A, X.B, X.aux_uword );
  252. if(status == false)
  253. {
  254. arma_stop_runtime_error("solve(): solution not found");
  255. }
  256. }
  257. template<typename eT, typename T1, typename T2>
  258. inline
  259. bool
  260. glue_solve_tri_default::apply(Mat<eT>& actual_out, const Base<eT,T1>& A_expr, const Base<eT,T2>& B_expr, const uword flags)
  261. {
  262. arma_extra_debug_sigprint();
  263. typedef typename get_pod_type<eT>::result T;
  264. const bool triu = bool(flags & solve_opts::flag_triu);
  265. const bool tril = bool(flags & solve_opts::flag_tril);
  266. const bool allow_ugly = false;
  267. arma_extra_debug_print("glue_solve_tri_default::apply(): enabled flags:");
  268. if(triu) { arma_extra_debug_print("triu"); }
  269. if(tril) { arma_extra_debug_print("tril"); }
  270. const quasi_unwrap<T1> U(A_expr.get_ref());
  271. const Mat<eT>& A = U.M;
  272. arma_debug_check( (A.is_square() == false), "solve(): matrix marked as triangular must be square sized" );
  273. const uword layout = (triu) ? uword(0) : uword(1);
  274. const bool is_alias = U.is_alias(actual_out);
  275. T rcond = T(0);
  276. bool status = false;
  277. Mat<eT> tmp;
  278. Mat<eT>& out = (is_alias) ? tmp : actual_out;
  279. status = auxlib::solve_trimat_rcond(out, rcond, A, B_expr.get_ref(), layout, allow_ugly); // A is not modified
  280. if( (status == true) && (rcond > T(0)) && (rcond < auxlib::epsilon_lapack(A)) )
  281. {
  282. arma_debug_warn("solve(): solution computed, but system seems singular to working precision (rcond: ", rcond, ")");
  283. }
  284. if(status == false)
  285. {
  286. arma_extra_debug_print("glue_solve_tri::apply(): solving rank deficient system");
  287. if(rcond > T(0))
  288. {
  289. arma_debug_warn("solve(): system seems singular (rcond: ", rcond, "); attempting approx solution");
  290. }
  291. else
  292. {
  293. arma_debug_warn("solve(): system seems singular; attempting approx solution");
  294. }
  295. Mat<eT> triA = (triu) ? trimatu(A) : trimatl(A); // trimatu() and trimatl() return the same type
  296. status = auxlib::solve_approx_svd(out, triA, B_expr.get_ref()); // triA is overwritten
  297. }
  298. if(status == false) { out.soft_reset(); }
  299. if(is_alias) { actual_out.steal_mem(out); }
  300. return status;
  301. }
  302. template<typename T1, typename T2>
  303. inline
  304. void
  305. glue_solve_tri::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_solve_tri>& X)
  306. {
  307. arma_extra_debug_sigprint();
  308. const bool status = glue_solve_tri::apply( out, X.A, X.B, X.aux_uword );
  309. if(status == false)
  310. {
  311. arma_stop_runtime_error("solve(): solution not found");
  312. }
  313. }
  314. template<typename eT, typename T1, typename T2>
  315. inline
  316. bool
  317. glue_solve_tri::apply(Mat<eT>& actual_out, const Base<eT,T1>& A_expr, const Base<eT,T2>& B_expr, const uword flags)
  318. {
  319. arma_extra_debug_sigprint();
  320. typedef typename get_pod_type<eT>::result T;
  321. const bool fast = bool(flags & solve_opts::flag_fast );
  322. const bool equilibrate = bool(flags & solve_opts::flag_equilibrate );
  323. const bool no_approx = bool(flags & solve_opts::flag_no_approx );
  324. const bool triu = bool(flags & solve_opts::flag_triu );
  325. const bool tril = bool(flags & solve_opts::flag_tril );
  326. const bool allow_ugly = bool(flags & solve_opts::flag_allow_ugly );
  327. const bool likely_sympd = bool(flags & solve_opts::flag_likely_sympd);
  328. const bool refine = bool(flags & solve_opts::flag_refine );
  329. const bool no_trimat = bool(flags & solve_opts::flag_no_trimat );
  330. arma_extra_debug_print("glue_solve_tri::apply(): enabled flags:");
  331. if(fast ) { arma_extra_debug_print("fast"); }
  332. if(equilibrate ) { arma_extra_debug_print("equilibrate"); }
  333. if(no_approx ) { arma_extra_debug_print("no_approx"); }
  334. if(triu ) { arma_extra_debug_print("triu"); }
  335. if(tril ) { arma_extra_debug_print("tril"); }
  336. if(allow_ugly ) { arma_extra_debug_print("allow_ugly"); }
  337. if(likely_sympd) { arma_extra_debug_print("likely_sympd"); }
  338. if(refine ) { arma_extra_debug_print("refine"); }
  339. if(no_trimat ) { arma_extra_debug_print("no_trimat"); }
  340. if(no_trimat || equilibrate || refine)
  341. {
  342. const uword mask = ~(solve_opts::flag_triu | solve_opts::flag_tril);
  343. return glue_solve_gen::apply(actual_out, ((triu) ? trimatu(A_expr.get_ref()) : trimatl(A_expr.get_ref())), B_expr, (flags & mask));
  344. }
  345. if(likely_sympd) { arma_debug_warn("solve(): option 'likely_sympd' ignored for triangular matrix"); }
  346. const quasi_unwrap<T1> U(A_expr.get_ref());
  347. const Mat<eT>& A = U.M;
  348. arma_debug_check( (A.is_square() == false), "solve(): matrix marked as triangular must be square sized" );
  349. const uword layout = (triu) ? uword(0) : uword(1);
  350. const bool is_alias = U.is_alias(actual_out);
  351. T rcond = T(0);
  352. bool status = false;
  353. Mat<eT> tmp;
  354. Mat<eT>& out = (is_alias) ? tmp : actual_out;
  355. if(fast)
  356. {
  357. status = auxlib::solve_trimat_fast(out, A, B_expr.get_ref(), layout); // A is not modified
  358. }
  359. else
  360. {
  361. status = auxlib::solve_trimat_rcond(out, rcond, A, B_expr.get_ref(), layout, allow_ugly); // A is not modified
  362. }
  363. if( (status == true) && (rcond > T(0)) && (rcond < auxlib::epsilon_lapack(A)) )
  364. {
  365. arma_debug_warn("solve(): solution computed, but system seems singular to working precision (rcond: ", rcond, ")");
  366. }
  367. if( (status == false) && (no_approx == false) )
  368. {
  369. arma_extra_debug_print("glue_solve_tri::apply(): solving rank deficient system");
  370. if(rcond > T(0))
  371. {
  372. arma_debug_warn("solve(): system seems singular (rcond: ", rcond, "); attempting approx solution");
  373. }
  374. else
  375. {
  376. arma_debug_warn("solve(): system seems singular; attempting approx solution");
  377. }
  378. Mat<eT> triA = (triu) ? trimatu(A) : trimatl(A); // trimatu() and trimatl() return the same type
  379. status = auxlib::solve_approx_svd(out, triA, B_expr.get_ref()); // triA is overwritten
  380. }
  381. if(status == false) { out.soft_reset(); }
  382. if(is_alias) { actual_out.steal_mem(out); }
  383. return status;
  384. }
  //! @}