// fn_spsolve.cpp: tests for arma::spsolve().

// Copyright 2011-2017 Ryan Curtin (http://www.ratml.org/)
// Copyright 2017 National ICT Australia (NICTA)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ------------------------------------------------------------------------
#include <cmath>
#include <complex>
#include <cstdlib>

#include <armadillo>
#include "catch.hpp"
  17. using namespace arma;
  18. #if defined(ARMA_USE_SUPERLU)
  19. TEST_CASE("fn_spsolve_sparse_test")
  20. {
  21. // We want to spsolve a system of equations, AX = B, where we want to recover
  22. // X and we have A and B, and A is sparse.
  23. for (size_t t = 0; t < 10; ++t)
  24. {
  25. const uword size = 5 * (t + 1);
  26. mat rX;
  27. rX.randu(size, size);
  28. sp_mat A;
  29. A.sprandu(size, size, 0.25);
  30. for (uword i = 0; i < size; ++i)
  31. {
  32. A(i, i) += rand();
  33. }
  34. mat B = A * rX;
  35. mat X;
  36. bool result = spsolve(X, A, B);
  37. REQUIRE( result );
  38. // Dense solver.
  39. mat dA(A);
  40. mat dX = solve(dA, B);
  41. REQUIRE( X.n_cols == dX.n_cols );
  42. REQUIRE( X.n_rows == dX.n_rows );
  43. for (uword i = 0; i < dX.n_cols; ++i)
  44. {
  45. for (uword j = 0; j < dX.n_rows; ++j)
  46. {
  47. REQUIRE( (double) X(j, i) == Approx((double) dX(j, i)) );
  48. }
  49. }
  50. }
  51. }
  52. TEST_CASE("fn_spsolve_sparse_nonsymmetric_test")
  53. {
  54. for (size_t t = 0; t < 10; ++t)
  55. {
  56. const uword r_size = 5 * (t + 1);
  57. const uword c_size = 3 * (t + 4);
  58. mat rX;
  59. rX.randu(r_size, c_size);
  60. sp_mat A;
  61. A.sprandu(r_size, r_size, 0.25);
  62. for (uword i = 0; i < r_size; ++i)
  63. {
  64. A(i, i) += rand();
  65. }
  66. mat B = A * rX;
  67. mat X;
  68. bool result = spsolve(X, A, B);
  69. REQUIRE( result );
  70. // Dense solver.
  71. mat dA(A);
  72. mat dX = solve(dA, B);
  73. REQUIRE( X.n_cols == dX.n_cols );
  74. REQUIRE( X.n_rows == dX.n_rows );
  75. for (uword i = 0; i < dX.n_cols; ++i)
  76. {
  77. for (uword j = 0; j < dX.n_rows; ++j)
  78. {
  79. REQUIRE( (double) X(j, i) == Approx((double) dX(j, i)) );
  80. }
  81. }
  82. }
  83. }
  84. TEST_CASE("fn_spsolve_sparse_float_test")
  85. {
  86. // We want to spsolve a system of equations, AX = B, where we want to recover
  87. // X and we have A and B, and A is sparse.
  88. for (size_t t = 0; t < 10; ++t)
  89. {
  90. const uword size = 5 * (t + 1);
  91. fmat rX;
  92. rX.randu(size, size);
  93. SpMat<float> A;
  94. A.sprandu(size, size, 0.25);
  95. for (uword i = 0; i < size; ++i)
  96. {
  97. A(i, i) += rand();
  98. }
  99. fmat B = A * rX;
  100. fmat X;
  101. bool result = spsolve(X, A, B);
  102. REQUIRE( result );
  103. // Dense solver.
  104. fmat dA(A);
  105. fmat dX = solve(dA, B);
  106. REQUIRE( X.n_cols == dX.n_cols );
  107. REQUIRE( X.n_rows == dX.n_rows );
  108. for (size_t i = 0; i < dX.n_cols; ++i)
  109. {
  110. for (size_t j = 0; j < dX.n_rows; ++j)
  111. {
  112. REQUIRE( (float) X(j, i) == Approx((float) dX(j, i)) );
  113. }
  114. }
  115. }
  116. }
  117. TEST_CASE("fn_spsolve_sparse_nonsymmetric_float_test")
  118. {
  119. for (size_t t = 0; t < 10; ++t)
  120. {
  121. const uword r_size = 5 * (t + 1);
  122. const uword c_size = 3 * (t + 4);
  123. fmat rX;
  124. rX.randu(r_size, c_size);
  125. SpMat<float> A;
  126. A.sprandu(r_size, r_size, 0.25);
  127. for (uword i = 0; i < r_size; ++i)
  128. {
  129. A(i, i) += rand();
  130. }
  131. fmat B = A * rX;
  132. fmat X;
  133. bool result = spsolve(X, A, B);
  134. REQUIRE( result );
  135. // Dense solver.
  136. fmat dA(A);
  137. fmat dX = solve(dA, B);
  138. REQUIRE( X.n_cols == dX.n_cols );
  139. REQUIRE( X.n_rows == dX.n_rows );
  140. for (uword i = 0; i < dX.n_cols; ++i)
  141. {
  142. for (uword j = 0; j < dX.n_rows; ++j)
  143. {
  144. REQUIRE( (float) X(j, i) == Approx((float) dX(j, i)) );
  145. }
  146. }
  147. }
  148. }
  149. TEST_CASE("fn_spsolve_sparse_complex_float_test")
  150. {
  151. // We want to spsolve a system of equations, AX = B, where we want to recover
  152. // X and we have A and B, and A is sparse.
  153. for (size_t t = 0; t < 10; ++t)
  154. {
  155. const uword size = 5 * (t + 1);
  156. Mat<std::complex<float> > rX;
  157. rX.randu(size, size);
  158. SpMat<std::complex<float> > A;
  159. A.sprandu(size, size, 0.25);
  160. for(uword i = 0; i < size; ++i)
  161. {
  162. A(i, i) += rand();
  163. }
  164. Mat<std::complex<float> > B = A * rX;
  165. Mat<std::complex<float> > X;
  166. bool result = spsolve(X, A, B);
  167. REQUIRE( result );
  168. // Dense solver.
  169. Mat<std::complex<float> > dA(A);
  170. Mat<std::complex<float> > dX = solve(dA, B);
  171. REQUIRE( X.n_cols == dX.n_cols );
  172. REQUIRE( X.n_rows == dX.n_rows );
  173. for (uword i = 0; i < dX.n_cols; ++i)
  174. {
  175. for (uword j = 0; j < dX.n_rows; ++j)
  176. {
  177. REQUIRE( (float) std::abs((std::complex<float>) X(j, i)) ==
  178. Approx((float) std::abs((std::complex<float>) dX(j, i))) );
  179. }
  180. }
  181. }
  182. }
  183. TEST_CASE("fn_spsolve_sparse_nonsymmetric_complex_float_test")
  184. {
  185. for (size_t t = 0; t < 10; ++t)
  186. {
  187. const uword r_size = 5 * (t + 1);
  188. const uword c_size = 3 * (t + 4);
  189. Mat<std::complex<float> > rX;
  190. rX.randu(r_size, c_size);
  191. SpMat<std::complex<float> > A;
  192. A.sprandu(r_size, r_size, 0.25);
  193. for (uword i = 0; i < r_size; ++i)
  194. {
  195. A(i, i) += rand();
  196. }
  197. Mat<std::complex<float> > B = A * rX;
  198. Mat<std::complex<float> > X;
  199. bool result = spsolve(X, A, B);
  200. REQUIRE( result );
  201. // Dense solver.
  202. Mat<std::complex<float> > dA(A);
  203. Mat<std::complex<float> > dX = solve(dA, B);
  204. REQUIRE( X.n_cols == dX.n_cols );
  205. REQUIRE( X.n_rows == dX.n_rows );
  206. for (uword i = 0; i < dX.n_cols; ++i)
  207. {
  208. for (uword j = 0; j < dX.n_rows; ++j)
  209. {
  210. REQUIRE( (float) std::abs((std::complex<float>) X(j, i)) ==
  211. Approx((float) std::abs((std::complex<float>) dX(j, i))) );
  212. }
  213. }
  214. }
  215. }
  216. TEST_CASE("fn_spsolve_sparse_complex_test")
  217. {
  218. // We want to spsolve a system of equations, AX = B, where we want to recover
  219. // X and we have A and B, and A is sparse.
  220. for (size_t t = 0; t < 10; ++t)
  221. {
  222. const uword size = 5 * (t + 1);
  223. Mat<std::complex<double> > rX;
  224. rX.randu(size, size);
  225. SpMat<std::complex<double> > A;
  226. A.sprandu(size, size, 0.25);
  227. for (uword i = 0; i < size; ++i)
  228. {
  229. A(i, i) += rand();
  230. }
  231. Mat<std::complex<double> > B = A * rX;
  232. Mat<std::complex<double> > X;
  233. bool result = spsolve(X, A, B);
  234. REQUIRE( result );
  235. // Dense solver.
  236. Mat<std::complex<double> > dA(A);
  237. Mat<std::complex<double> > dX = solve(dA, B);
  238. REQUIRE( X.n_cols == dX.n_cols );
  239. REQUIRE( X.n_rows == dX.n_rows );
  240. for (uword i = 0; i < dX.n_cols; ++i)
  241. {
  242. for (uword j = 0; j < dX.n_rows; ++j)
  243. {
  244. REQUIRE( (double) std::abs((std::complex<double>) X(j, i)) ==
  245. Approx((double) std::abs((std::complex<double>) dX(j, i))) );
  246. }
  247. }
  248. }
  249. }
  250. TEST_CASE("fn_spsolve_sparse_nonsymmetric_complex_test")
  251. {
  252. for (size_t t = 0; t < 10; ++t)
  253. {
  254. const uword r_size = 5 * (t + 1);
  255. const uword c_size = 3 * (t + 4);
  256. Mat<std::complex<double> > rX;
  257. rX.randu(r_size, c_size);
  258. SpMat<std::complex<double> > A;
  259. A.sprandu(r_size, r_size, 0.25);
  260. for (uword i = 0; i < r_size; ++i)
  261. {
  262. A(i, i) += rand();
  263. }
  264. Mat<std::complex<double> > B = A * rX;
  265. Mat<std::complex<double> > X;
  266. bool result = spsolve(X, A, B);
  267. REQUIRE( result );
  268. // Dense solver.
  269. Mat<std::complex<double> > dA(A);
  270. Mat<std::complex<double> > dX = solve(dA, B);
  271. REQUIRE( X.n_cols == dX.n_cols );
  272. REQUIRE( X.n_rows == dX.n_rows );
  273. for (uword i = 0; i < dX.n_cols; ++i)
  274. {
  275. for (uword j = 0; j < dX.n_rows; ++j)
  276. {
  277. REQUIRE( (double) std::abs((std::complex<double>) X(j, i)) ==
  278. Approx((double) std::abs((std::complex<double>) dX(j, i))) );
  279. }
  280. }
  281. }
  282. }
  283. TEST_CASE("fn_spsolve_delayed_sparse_test")
  284. {
  285. const uword size = 10;
  286. mat rX;
  287. rX.randu(size, size);
  288. sp_mat A;
  289. A.sprandu(size, size, 0.25);
  290. for (uword i = 0; i < size; ++i)
  291. {
  292. A(i, i) += rand();
  293. }
  294. mat B = A * rX;
  295. mat X;
  296. bool result = spsolve(X, A, B);
  297. REQUIRE( result );
  298. mat dX = spsolve(A, B);
  299. REQUIRE( X.n_cols == dX.n_cols );
  300. REQUIRE( X.n_rows == dX.n_rows );
  301. for (uword i = 0; i < dX.n_cols; ++i)
  302. {
  303. for (uword j = 0; j < dX.n_rows; ++j)
  304. {
  305. REQUIRE( (double) X(j, i) == Approx((double) dX(j, i)) );
  306. }
  307. }
  308. }
  309. TEST_CASE("fn_spsolve_superlu_solve_test")
  310. {
  311. // Solve this matrix, as in the examples:
  312. // [[19 0 21 21 0]
  313. // [12 21 0 0 0]
  314. // [ 0 12 16 0 0]
  315. // [ 0 0 0 5 21]
  316. // [12 12 0 0 18]]
  317. sp_mat b(5, 5);
  318. b(0, 0) = 19;
  319. b(0, 2) = 21;
  320. b(0, 3) = 21;
  321. b(1, 0) = 12;
  322. b(1, 1) = 21;
  323. b(2, 1) = 12;
  324. b(2, 2) = 16;
  325. b(3, 3) = 5;
  326. b(3, 4) = 21;
  327. b(4, 0) = 12;
  328. b(4, 1) = 12;
  329. b(4, 4) = 18;
  330. mat db(b);
  331. sp_mat a;
  332. a.eye(5, 5);
  333. mat da(a);
  334. mat x;
  335. spsolve(x, a, db);
  336. mat dx = solve(da, db);
  337. for (uword i = 0; i < x.n_cols; ++i)
  338. {
  339. for (uword j = 0; j < x.n_rows; ++j)
  340. {
  341. REQUIRE( (double) x(j, i) == Approx(dx(j, i)) );
  342. }
  343. }
  344. }
  345. TEST_CASE("fn_spsolve_random_superlu_solve_test")
  346. {
  347. // Try to solve some random systems.
  348. const size_t iterations = 10;
  349. for (size_t it = 0; it < iterations; ++it)
  350. {
  351. sp_mat a;
  352. a.sprandu(50, 50, 0.3);
  353. sp_mat trueX;
  354. trueX.sprandu(50, 50, 0.3);
  355. sp_mat b = a * trueX;
  356. // Get things into the right format.
  357. mat db(b);
  358. mat x;
  359. spsolve(x, a, db);
  360. for (uword i = 0; i < x.n_cols; ++i)
  361. {
  362. for (uword j = 0; j < x.n_rows; ++j)
  363. {
  364. REQUIRE( x(j, i) == Approx((double) trueX(j, i)) );
  365. }
  366. }
  367. }
  368. }
  369. TEST_CASE("fn_spsolve_float_superlu_solve_test")
  370. {
  371. // Solve this matrix, as in the examples:
  372. // [[19 0 21 21 0]
  373. // [12 21 0 0 0]
  374. // [ 0 12 16 0 0]
  375. // [ 0 0 0 5 21]
  376. // [12 12 0 0 18]]
  377. sp_fmat b(5, 5);
  378. b(0, 0) = 19;
  379. b(0, 2) = 21;
  380. b(0, 3) = 21;
  381. b(1, 0) = 12;
  382. b(1, 1) = 21;
  383. b(2, 1) = 12;
  384. b(2, 2) = 16;
  385. b(3, 3) = 5;
  386. b(3, 4) = 21;
  387. b(4, 0) = 12;
  388. b(4, 1) = 12;
  389. b(4, 4) = 18;
  390. fmat db(b);
  391. sp_fmat a;
  392. a.eye(5, 5);
  393. fmat da(a);
  394. fmat x;
  395. spsolve(x, a, db);
  396. fmat dx = solve(da, db);
  397. for (uword i = 0; i < x.n_cols; ++i)
  398. {
  399. for (uword j = 0; j < x.n_rows; ++j)
  400. {
  401. REQUIRE( (float) x(j, i) == Approx(dx(j, i)) );
  402. }
  403. }
  404. }
  405. TEST_CASE("fn_spsolve_float_random_superlu_solve_test")
  406. {
  407. // Try to solve some random systems.
  408. const size_t iterations = 10;
  409. for (size_t it = 0; it < iterations; ++it)
  410. {
  411. sp_fmat a;
  412. a.sprandu(50, 50, 0.3);
  413. sp_fmat trueX;
  414. trueX.sprandu(50, 50, 0.3);
  415. sp_fmat b = a * trueX;
  416. // Get things into the right format.
  417. fmat db(b);
  418. fmat x;
  419. spsolve(x, a, db);
  420. for (uword i = 0; i < x.n_cols; ++i)
  421. {
  422. for (uword j = 0; j < x.n_rows; ++j)
  423. {
  424. if (std::abs(trueX(j, i)) < 0.001)
  425. REQUIRE( std::abs(x(j, i)) < 0.005 );
  426. else
  427. REQUIRE( trueX(j, i) == Approx((float) x(j, i)).epsilon(0.01) );
  428. }
  429. }
  430. }
  431. }
  432. TEST_CASE("fn_spsolve_cx_float_superlu_solve_test")
  433. {
  434. // Solve this matrix, as in the examples:
  435. // [[19 0 21 21 0]
  436. // [12 21 0 0 0]
  437. // [ 0 12 16 0 0]
  438. // [ 0 0 0 5 21]
  439. // [12 12 0 0 18]] (imaginary part is the same)
  440. SpMat<std::complex<float> > b(5, 5);
  441. b(0, 0) = std::complex<float>(19, 19);
  442. b(0, 2) = std::complex<float>(21, 21);
  443. b(0, 3) = std::complex<float>(21, 21);
  444. b(1, 0) = std::complex<float>(12, 12);
  445. b(1, 1) = std::complex<float>(21, 21);
  446. b(2, 1) = std::complex<float>(12, 12);
  447. b(2, 2) = std::complex<float>(16, 16);
  448. b(3, 3) = std::complex<float>(5, 5);
  449. b(3, 4) = std::complex<float>(21, 21);
  450. b(4, 0) = std::complex<float>(12, 12);
  451. b(4, 1) = std::complex<float>(12, 12);
  452. b(4, 4) = std::complex<float>(18, 18);
  453. Mat<std::complex<float> > db(b);
  454. SpMat<std::complex<float> > a;
  455. a.eye(5, 5);
  456. Mat<std::complex<float> > da(a);
  457. Mat<std::complex<float> > x;
  458. spsolve(x, a, db);
  459. Mat<std::complex<float> > dx = solve(da, db);
  460. for (uword i = 0; i < x.n_cols; ++i)
  461. {
  462. for (uword j = 0; j < x.n_rows; ++j)
  463. {
  464. if (std::abs(x(j, i)) < 0.001 )
  465. {
  466. REQUIRE( std::abs(dx(j, i)) < 0.005 );
  467. }
  468. else
  469. {
  470. REQUIRE( ((std::complex<float>) x(j, i)).real() ==
  471. Approx(dx(j, i).real()).epsilon(0.01) );
  472. REQUIRE( ((std::complex<float>) x(j, i)).imag() ==
  473. Approx(dx(j, i).imag()).epsilon(0.01) );
  474. }
  475. }
  476. }
  477. }
  478. TEST_CASE("fn_spsolve_cx_float_random_superlu_solve_test")
  479. {
  480. // Try to solve some random systems.
  481. const size_t iterations = 10;
  482. for (size_t it = 0; it < iterations; ++it)
  483. {
  484. SpMat<std::complex<float> > a;
  485. a.sprandu(50, 50, 0.3);
  486. SpMat<std::complex<float> > trueX;
  487. trueX.sprandu(50, 50, 0.3);
  488. SpMat<std::complex<float> > b = a * trueX;
  489. // Get things into the right format.
  490. Mat<std::complex<float> > db(b);
  491. Mat<std::complex<float> > x;
  492. spsolve(x, a, db);
  493. for (uword i = 0; i < x.n_cols; ++i)
  494. {
  495. for (uword j = 0; j < x.n_rows; ++j)
  496. {
  497. if (std::abs((std::complex<float>) trueX(j, i)) < 0.001 )
  498. {
  499. REQUIRE( std::abs(x(j, i)) < 0.001 );
  500. }
  501. else
  502. {
  503. REQUIRE( ((std::complex<float>) trueX(j, i)).real() ==
  504. Approx(x(j, i).real()).epsilon(0.01) );
  505. REQUIRE( ((std::complex<float>) trueX(j, i)).imag() ==
  506. Approx(x(j, i).imag()).epsilon(0.01) );
  507. }
  508. }
  509. }
  510. }
  511. }
  512. TEST_CASE("fn_spsolve_cx_superlu_solve_test")
  513. {
  514. // Solve this matrix, as in the examples:
  515. // [[19 0 21 21 0]
  516. // [12 21 0 0 0]
  517. // [ 0 12 16 0 0]
  518. // [ 0 0 0 5 21]
  519. // [12 12 0 0 18]] (imaginary part is the same)
  520. SpMat<std::complex<double> > b(5, 5);
  521. b(0, 0) = std::complex<double>(19, 19);
  522. b(0, 2) = std::complex<double>(21, 21);
  523. b(0, 3) = std::complex<double>(21, 21);
  524. b(1, 0) = std::complex<double>(12, 12);
  525. b(1, 1) = std::complex<double>(21, 21);
  526. b(2, 1) = std::complex<double>(12, 12);
  527. b(2, 2) = std::complex<double>(16, 16);
  528. b(3, 3) = std::complex<double>(5, 5);
  529. b(3, 4) = std::complex<double>(21, 21);
  530. b(4, 0) = std::complex<double>(12, 12);
  531. b(4, 1) = std::complex<double>(12, 12);
  532. b(4, 4) = std::complex<double>(18, 18);
  533. cx_mat db(b);
  534. sp_cx_mat a;
  535. a.eye(5, 5);
  536. cx_mat da(a);
  537. cx_mat x;
  538. spsolve(x, a, db);
  539. cx_mat dx = solve(da, db);
  540. for (uword i = 0; i < x.n_cols; ++i)
  541. {
  542. for (uword j = 0; j < x.n_rows; ++j)
  543. {
  544. if (std::abs(x(j, i)) < 0.001)
  545. {
  546. REQUIRE( std::abs(dx(j, i)) < 0.005 );
  547. }
  548. else
  549. {
  550. REQUIRE( ((std::complex<double>) x(j, i)).real() ==
  551. Approx(dx(j, i).real()).epsilon(0.01) );
  552. REQUIRE( ((std::complex<double>) x(j, i)).imag() ==
  553. Approx(dx(j, i).imag()).epsilon(0.01) );
  554. }
  555. }
  556. }
  557. }
  558. TEST_CASE("fn_spsolve_cx_random_superlu_solve_test")
  559. {
  560. // Try to solve some random systems.
  561. const size_t iterations = 10;
  562. for (size_t it = 0; it < iterations; ++it)
  563. {
  564. sp_cx_mat a;
  565. a.sprandu(50, 50, 0.3);
  566. sp_cx_mat trueX;
  567. trueX.sprandu(50, 50, 0.3);
  568. sp_cx_mat b = a * trueX;
  569. // Get things into the right format.
  570. cx_mat db(b);
  571. cx_mat x;
  572. spsolve(x, a, db);
  573. for (uword i = 0; i < x.n_cols; ++i)
  574. {
  575. for (uword j = 0; j < x.n_rows; ++j)
  576. {
  577. if (std::abs((std::complex<double>) trueX(j, i)) < 0.001)
  578. {
  579. REQUIRE( std::abs(x(j, i)) < 0.005 );
  580. }
  581. else
  582. {
  583. REQUIRE( ((std::complex<double>) trueX(j, i)).real() ==
  584. Approx(x(j, i).real()).epsilon(0.01) );
  585. REQUIRE( ((std::complex<double>) trueX(j, i)).imag() ==
  586. Approx(x(j, i).imag()).epsilon(0.01) );
  587. }
  588. }
  589. }
  590. }
  591. }
  592. TEST_CASE("fn_spsolve_function_test")
  593. {
  594. sp_mat a;
  595. a.sprandu(50, 50, 0.3);
  596. sp_mat trueX;
  597. trueX.sprandu(50, 50, 0.3);
  598. sp_mat b = a * trueX;
  599. // Get things into the right format.
  600. mat db(b);
  601. mat x;
  602. // Mostly these are compilation tests.
  603. spsolve(x, a, db);
  604. x = spsolve(a, db); // Test another overload.
  605. x = spsolve(a, db + 0.0);
  606. spsolve(x, a, db + 0.0);
  607. for (uword i = 0; i < x.n_cols; ++i)
  608. {
  609. for (uword j = 0; j < x.n_rows; ++j)
  610. {
  611. REQUIRE( (double) trueX(j, i) == Approx(x(j, i)) );
  612. }
  613. }
  614. }
  615. TEST_CASE("fn_spsolve_float_function_test")
  616. {
  617. sp_fmat a;
  618. a.sprandu(50, 50, 0.3);
  619. sp_fmat trueX;
  620. trueX.sprandu(50, 50, 0.3);
  621. sp_fmat b = a * trueX;
  622. // Get things into the right format.
  623. fmat db(b);
  624. fmat x;
  625. // Mostly these are compilation tests.
  626. spsolve(x, a, db);
  627. x = spsolve(a, db); // Test another overload.
  628. x = spsolve(a, db + 0.0);
  629. spsolve(x, a, db + 0.0);
  630. for (uword i = 0; i < x.n_cols; ++i)
  631. {
  632. for (uword j = 0; j < x.n_rows; ++j)
  633. {
  634. if (std::abs(trueX(j, i)) < 0.001)
  635. {
  636. REQUIRE( std::abs(x(j, i)) < 0.001 );
  637. }
  638. else
  639. {
  640. REQUIRE( (float) trueX(j, i) == Approx(x(j, i)).epsilon(0.01) );
  641. }
  642. }
  643. }
  644. }
  645. TEST_CASE("fn_spsolve_cx_function_test")
  646. {
  647. sp_cx_mat a;
  648. a.sprandu(50, 50, 0.3);
  649. sp_cx_mat trueX;
  650. trueX.sprandu(50, 50, 0.3);
  651. sp_cx_mat b = a * trueX;
  652. // Get things into the right format.
  653. cx_mat db(b);
  654. cx_mat x;
  655. // Mostly these are compilation tests.
  656. spsolve(x, a, db);
  657. x = spsolve(a, db); // Test another overload.
  658. x = spsolve(a, db + std::complex<double>(0.0));
  659. spsolve(x, a, db + std::complex<double>(0.0));
  660. for (uword i = 0; i < x.n_cols; ++i)
  661. {
  662. for (uword j = 0; j < x.n_rows; ++j)
  663. {
  664. if (std::abs((std::complex<double>) trueX(j, i)) < 0.001)
  665. {
  666. REQUIRE( std::abs(x(j, i)) < 0.005 );
  667. }
  668. else
  669. {
  670. REQUIRE( ((std::complex<double>) trueX(j, i)).real() ==
  671. Approx(x(j, i).real()).epsilon(0.01) );
  672. REQUIRE( ((std::complex<double>) trueX(j, i)).imag() ==
  673. Approx(x(j, i).imag()).epsilon(0.01) );
  674. }
  675. }
  676. }
  677. }
  678. TEST_CASE("fn_spsolve_cx_float_function_test")
  679. {
  680. sp_cx_fmat a;
  681. a.sprandu(50, 50, 0.3);
  682. sp_cx_fmat trueX;
  683. trueX.sprandu(50, 50, 0.3);
  684. sp_cx_fmat b = a * trueX;
  685. // Get things into the right format.
  686. cx_fmat db(b);
  687. cx_fmat x;
  688. // Mostly these are compilation tests.
  689. spsolve(x, a, db);
  690. x = spsolve(a, db); // Test another overload.
  691. x = spsolve(a, db + std::complex<float>(0.0));
  692. spsolve(x, a, db + std::complex<float>(0.0));
  693. for (uword i = 0; i < x.n_cols; ++i)
  694. {
  695. for (uword j = 0; j < x.n_rows; ++j)
  696. {
  697. if (std::abs((std::complex<float>) trueX(j, i)) < 0.001 )
  698. {
  699. REQUIRE( std::abs(x(j, i)) < 0.005 );
  700. }
  701. else
  702. {
  703. REQUIRE( ((std::complex<float>) trueX(j, i)).real() ==
  704. Approx(x(j, i).real()).epsilon(0.01) );
  705. REQUIRE( ((std::complex<float>) trueX(j, i)).imag() ==
  706. Approx(x(j, i).imag()).epsilon(0.01) );
  707. }
  708. }
  709. }
  710. }
  711. #endif