fn_var.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549
  1. // Copyright 2011-2017 Ryan Curtin (http://www.ratml.org/)
  2. // Copyright 2017 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. #include <armadillo>
  16. #include "catch.hpp"
  17. using namespace arma;
  18. TEST_CASE("fn_var_empty_sparse_test")
  19. {
  20. SpMat<double> m(100, 100);
  21. SpRow<double> result = var(m);
  22. REQUIRE( result.n_cols == 100 );
  23. REQUIRE( result.n_rows == 1 );
  24. for (uword i = 0; i < 100; ++i)
  25. {
  26. REQUIRE( (double) result[i] == Approx(0.0) );
  27. }
  28. result = var(m, 0, 0);
  29. REQUIRE( result.n_cols == 100 );
  30. REQUIRE( result.n_rows == 1 );
  31. for (uword i = 0; i < 100; ++i)
  32. {
  33. REQUIRE( (double) result[i] == Approx(0.0) );
  34. }
  35. result = var(m, 1, 0);
  36. REQUIRE( result.n_cols == 100 );
  37. REQUIRE( result.n_rows == 1 );
  38. for (uword i = 0; i < 100; ++i)
  39. {
  40. REQUIRE( (double) result[i] == Approx(0.0) );
  41. }
  42. result = var(m, 1);
  43. REQUIRE( result.n_cols == 100 );
  44. REQUIRE( result.n_rows == 1 );
  45. for (uword i = 0; i < 100; ++i)
  46. {
  47. REQUIRE( (double) result[i] == Approx(0.0) );
  48. }
  49. SpCol<double> colres = var(m, 1, 1);
  50. REQUIRE( colres.n_cols == 1 );
  51. REQUIRE( colres.n_rows == 100 );
  52. for (uword i = 0; i < 100; ++i)
  53. {
  54. REQUIRE( (double) colres[i] == Approx(0.0) );
  55. }
  56. colres = var(m, 0, 1);
  57. REQUIRE( colres.n_cols == 1 );
  58. REQUIRE( colres.n_rows == 100 );
  59. for (uword i = 0; i < 100; ++i)
  60. {
  61. REQUIRE( (double) colres[i] == Approx(0.0) );
  62. }
  63. }
  64. TEST_CASE("fn_var_empty_cx_sparse_test")
  65. {
  66. SpMat<std::complex<double> > m(100, 100);
  67. SpRow<double> result = var(m);
  68. REQUIRE( result.n_cols == 100 );
  69. REQUIRE( result.n_rows == 1 );
  70. for (uword i = 0; i < 100; ++i)
  71. {
  72. REQUIRE( (double) result[i] == Approx(0.0) );
  73. }
  74. result = var(m, 0, 0);
  75. REQUIRE( result.n_cols == 100 );
  76. REQUIRE( result.n_rows == 1 );
  77. for (uword i = 0; i < 100; ++i)
  78. {
  79. REQUIRE( (double) result[i] == Approx(0.0) );
  80. }
  81. result = var(m, 1, 0);
  82. REQUIRE( result.n_cols == 100 );
  83. REQUIRE( result.n_rows == 1 );
  84. for (uword i = 0; i < 100; ++i)
  85. {
  86. REQUIRE( (double) result[i] == Approx(0.0) );
  87. }
  88. result = var(m, 1);
  89. REQUIRE( result.n_cols == 100 );
  90. REQUIRE( result.n_rows == 1 );
  91. for (uword i = 0; i < 100; ++i)
  92. {
  93. REQUIRE( (double) result[i] == Approx(0.0) );
  94. }
  95. SpCol<double> colres = var(m, 1, 1);
  96. REQUIRE( colres.n_cols == 1 );
  97. REQUIRE( colres.n_rows == 100 );
  98. for (uword i = 0; i < 100; ++i)
  99. {
  100. REQUIRE( (double) colres[i] == Approx(0.0) );
  101. }
  102. colres = var(m, 0, 1);
  103. REQUIRE( colres.n_cols == 1 );
  104. REQUIRE( colres.n_rows == 100 );
  105. for (uword i = 0; i < 100; ++i)
  106. {
  107. REQUIRE( (double) colres[i] == Approx(0.0) );
  108. }
  109. }
  110. TEST_CASE("fn_var_sparse_test")
  111. {
  112. // Create a random matrix and do variance testing on it, with varying levels
  113. // of nonzero (eventually this becomes a fully dense matrix).
  114. for (int i = 0; i < 10; ++i)
  115. {
  116. SpMat<double> x;
  117. x.sprandu(50, 75, ((double) (i + 1)) / 10);
  118. mat d(x);
  119. SpRow<double> rr = var(x);
  120. rowvec drr = var(d);
  121. REQUIRE( rr.n_rows == 1 );
  122. REQUIRE( rr.n_cols == 75 );
  123. for (uword j = 0; j < 75; ++j)
  124. {
  125. REQUIRE( drr[j] == Approx((double) rr[j]) );
  126. }
  127. rr = var(x, 0);
  128. REQUIRE( rr.n_rows == 1 );
  129. REQUIRE( rr.n_cols == 75 );
  130. for (uword j = 0; j < 75; ++j)
  131. {
  132. REQUIRE( drr[j] == Approx((double) rr[j]) );
  133. }
  134. rr = var(x, 1, 0);
  135. drr = var(d, 1, 0);
  136. REQUIRE( rr.n_rows == 1 );
  137. REQUIRE( rr.n_cols == 75 );
  138. for (uword j = 0; j < 75; ++j)
  139. {
  140. REQUIRE( drr[j] == Approx((double) rr[j]) );
  141. }
  142. SpCol<double> cr = var(x, 0, 1);
  143. vec dcr = var(d, 0, 1);
  144. REQUIRE( cr.n_rows == 50 );
  145. REQUIRE( cr.n_cols == 1 );
  146. for (uword j = 0; j < 50; ++j)
  147. {
  148. REQUIRE( dcr[j] == Approx((double) cr[j]) );
  149. }
  150. cr = var(x, 1, 1);
  151. dcr = var(d, 1, 1);
  152. REQUIRE( cr.n_rows == 50 );
  153. REQUIRE( cr.n_cols == 1 );
  154. for (uword j = 0; j < 50; ++j)
  155. {
  156. REQUIRE( dcr[j] == Approx((double) cr[j]) );
  157. }
  158. // Now on a subview.
  159. rr = var(x.submat(11, 11, 30, 45), 0, 0);
  160. drr = var(d.submat(11, 11, 30, 45), 0, 0);
  161. REQUIRE( rr.n_rows == 1 );
  162. REQUIRE( rr.n_cols == 35 );
  163. for (uword j = 0; j < 35; ++j)
  164. {
  165. REQUIRE( drr[j] == Approx((double) rr[j]) );
  166. }
  167. rr = var(x.submat(11, 11, 30, 45), 1, 0);
  168. drr = var(d.submat(11, 11, 30, 45), 1, 0);
  169. REQUIRE( rr.n_rows == 1 );
  170. REQUIRE( rr.n_cols == 35 );
  171. for (uword j = 0; j < 35; ++j)
  172. {
  173. REQUIRE( drr[j] == Approx((double) rr[j]) );
  174. }
  175. cr = var(x.submat(11, 11, 30, 45), 0, 1);
  176. dcr = var(d.submat(11, 11, 30, 45), 0, 1);
  177. REQUIRE( cr.n_rows == 20 );
  178. REQUIRE( cr.n_cols == 1 );
  179. for (uword j = 0; j < 20; ++j)
  180. {
  181. REQUIRE( dcr[j] == Approx((double) cr[j]) );
  182. }
  183. cr = var(x.submat(11, 11, 30, 45), 1, 1);
  184. dcr = var(d.submat(11, 11, 30, 45), 1, 1);
  185. REQUIRE( cr.n_rows == 20 );
  186. REQUIRE( cr.n_cols == 1 );
  187. for (uword j = 0; j < 20; ++j)
  188. {
  189. REQUIRE( dcr[j] == Approx((double) cr[j]) );
  190. }
  191. // Now on an SpOp (spop_scalar_times)
  192. rr = var(3.0 * x, 0, 0);
  193. drr = var(3.0 * d, 0, 0);
  194. REQUIRE( rr.n_rows == 1 );
  195. REQUIRE( rr.n_cols == 75 );
  196. for (uword j = 0; j < 75; ++j)
  197. {
  198. REQUIRE( drr[j] == Approx((double) rr[j]) );
  199. }
  200. rr = var(3.0 * x, 1, 0);
  201. drr = var(3.0 * d, 1, 0);
  202. REQUIRE( rr.n_rows == 1 );
  203. REQUIRE( rr.n_cols == 75 );
  204. for (uword j = 0; j < 75; ++j)
  205. {
  206. REQUIRE( drr[j] == Approx((double) rr[j]) );
  207. }
  208. cr = var(4.5 * x, 0, 1);
  209. dcr = var(4.5 * d, 0, 1);
  210. REQUIRE( cr.n_rows == 50 );
  211. REQUIRE( cr.n_cols == 1 );
  212. for (uword j = 0; j < 50; ++j)
  213. {
  214. REQUIRE( dcr[j] == Approx((double) cr[j]) );
  215. }
  216. cr = var(4.5 * x, 1, 1);
  217. dcr = var(4.5 * d, 1, 1);
  218. REQUIRE( cr.n_rows == 50 );
  219. REQUIRE( cr.n_cols == 1 );
  220. for (uword j = 0; j < 50; ++j)
  221. {
  222. REQUIRE( dcr[j] == Approx((double) cr[j]) );
  223. }
  224. // Now on an SpGlue!
  225. SpMat<double> y;
  226. y.sprandu(50, 75, 0.3);
  227. mat e(y);
  228. rr = var(x + y);
  229. drr = var(d + e);
  230. REQUIRE( rr.n_rows == 1 );
  231. REQUIRE( rr.n_cols == 75 );
  232. for (uword j = 0; j < 75; ++j)
  233. {
  234. REQUIRE( drr[j] == Approx((double) rr[j]) );
  235. }
  236. rr = var(x + y, 1);
  237. drr = var(d + e, 1);
  238. REQUIRE( rr.n_rows == 1 );
  239. REQUIRE( rr.n_cols == 75 );
  240. for (uword j = 0; j < 75; ++j)
  241. {
  242. REQUIRE( drr[j] == Approx((double) rr[j]) );
  243. }
  244. cr = var(x + y, 0, 1);
  245. dcr = var(d + e, 0, 1);
  246. REQUIRE( cr.n_rows == 50 );
  247. REQUIRE( cr.n_cols == 1 );
  248. for (uword j = 0; j < 50; ++j)
  249. {
  250. REQUIRE( dcr[j] == Approx((double) cr[j]) );
  251. }
  252. cr = var(x + y, 1, 1);
  253. dcr = var(d + e, 1, 1);
  254. REQUIRE( cr.n_rows == 50 );
  255. REQUIRE( cr.n_cols == 1 );
  256. for (uword j = 0; j < 50; ++j)
  257. {
  258. REQUIRE( dcr[j] == Approx((double) cr[j]) );
  259. }
  260. }
  261. }
  262. TEST_CASE("fn_var_sparse_cx_test")
  263. {
  264. // Create a random matrix and do variance testing on it, with varying levels
  265. // of nonzero (eventually this becomes a fully dense matrix).
  266. for (int i = 0; i < 10; ++i)
  267. {
  268. SpMat<std::complex<double> > x;
  269. x.sprandu(50, 75, ((double) (i + 1)) / 10);
  270. cx_mat d(x);
  271. SpRow<double> rr = var(x);
  272. rowvec drr = var(d);
  273. REQUIRE( rr.n_rows == 1 );
  274. REQUIRE( rr.n_cols == 75 );
  275. for (uword j = 0; j < 75; ++j)
  276. {
  277. REQUIRE( drr[j] == Approx((double) rr[j]) );
  278. }
  279. rr = var(x, 0);
  280. REQUIRE( rr.n_rows == 1 );
  281. REQUIRE( rr.n_cols == 75 );
  282. for (uword j = 0; j < 75; ++j)
  283. {
  284. REQUIRE( drr[j] == Approx((double) rr[j]) );
  285. }
  286. rr = var(x, 1, 0);
  287. drr = var(d, 1, 0);
  288. REQUIRE( rr.n_rows == 1 );
  289. REQUIRE( rr.n_cols == 75 );
  290. for (uword j = 0; j < 75; ++j)
  291. {
  292. REQUIRE( drr[j] == Approx((double) rr[j]) );
  293. }
  294. SpCol<double> cr = var(x, 0, 1);
  295. vec dcr = var(d, 0, 1);
  296. REQUIRE( cr.n_rows == 50 );
  297. REQUIRE( cr.n_cols == 1 );
  298. for (uword j = 0; j < 50; ++j)
  299. {
  300. REQUIRE( dcr[j] == Approx((double) cr[j]) );
  301. }
  302. cr = var(x, 1, 1);
  303. dcr = var(d, 1, 1);
  304. REQUIRE( cr.n_rows == 50 );
  305. REQUIRE( cr.n_cols == 1 );
  306. for (uword j = 0; j < 50; ++j)
  307. {
  308. REQUIRE( dcr[j] == Approx((double) cr[j]) );
  309. }
  310. // Now on a subview.
  311. rr = var(x.submat(11, 11, 30, 45), 0, 0);
  312. drr = var(d.submat(11, 11, 30, 45), 0, 0);
  313. REQUIRE( rr.n_rows == 1 );
  314. REQUIRE( rr.n_cols == 35 );
  315. for (uword j = 0; j < 35; ++j)
  316. {
  317. REQUIRE( drr[j] == Approx((double) rr[j]) );
  318. }
  319. rr = var(x.submat(11, 11, 30, 45), 1, 0);
  320. drr = var(d.submat(11, 11, 30, 45), 1, 0);
  321. REQUIRE( rr.n_rows == 1 );
  322. REQUIRE( rr.n_cols == 35 );
  323. for (uword j = 0; j < 35; ++j)
  324. {
  325. REQUIRE( drr[j] == Approx((double) rr[j]) );
  326. }
  327. cr = var(x.submat(11, 11, 30, 45), 0, 1);
  328. dcr = var(d.submat(11, 11, 30, 45), 0, 1);
  329. REQUIRE( cr.n_rows == 20 );
  330. REQUIRE( cr.n_cols == 1 );
  331. for (uword j = 0; j < 20; ++j)
  332. {
  333. REQUIRE( dcr[j] == Approx((double) cr[j]) );
  334. }
  335. cr = var(x.submat(11, 11, 30, 45), 1, 1);
  336. dcr = var(d.submat(11, 11, 30, 45), 1, 1);
  337. REQUIRE( cr.n_rows == 20 );
  338. REQUIRE( cr.n_cols == 1 );
  339. for (uword j = 0; j < 20; ++j)
  340. {
  341. REQUIRE( dcr[j] == Approx((double) cr[j]) );
  342. }
  343. // Now on an SpOp (spop_scalar_times)
  344. rr = var(3.0 * x, 0, 0);
  345. drr = var(3.0 * d, 0, 0);
  346. REQUIRE( rr.n_rows == 1 );
  347. REQUIRE( rr.n_cols == 75 );
  348. for (uword j = 0; j < 75; ++j)
  349. {
  350. REQUIRE( drr[j] == Approx((double) rr[j]) );
  351. }
  352. rr = var(3.0 * x, 1, 0);
  353. drr = var(3.0 * d, 1, 0);
  354. REQUIRE( rr.n_rows == 1 );
  355. REQUIRE( rr.n_cols == 75 );
  356. for (uword j = 0; j < 75; ++j)
  357. {
  358. REQUIRE( drr[j] == Approx((double) rr[j]) );
  359. }
  360. cr = var(4.5 * x, 0, 1);
  361. dcr = var(4.5 * d, 0, 1);
  362. REQUIRE( cr.n_rows == 50 );
  363. REQUIRE( cr.n_cols == 1 );
  364. for (uword j = 0; j < 50; ++j)
  365. {
  366. REQUIRE( dcr[j] == Approx((double) cr[j]) );
  367. }
  368. cr = var(4.5 * x, 1, 1);
  369. dcr = var(4.5 * d, 1, 1);
  370. REQUIRE( cr.n_rows == 50 );
  371. REQUIRE( cr.n_cols == 1 );
  372. for (uword j = 0; j < 50; ++j)
  373. {
  374. REQUIRE( dcr[j] == Approx((double) cr[j]) );
  375. }
  376. // Now on an SpGlue!
  377. SpMat<std::complex<double> > y;
  378. y.sprandu(50, 75, 0.3);
  379. cx_mat e(y);
  380. rr = var(x + y);
  381. drr = var(d + e);
  382. REQUIRE( rr.n_rows == 1 );
  383. REQUIRE( rr.n_cols == 75 );
  384. for (uword j = 0; j < 75; ++j)
  385. {
  386. REQUIRE( drr[j] == Approx((double) rr[j]) );
  387. }
  388. rr = var(x + y, 1);
  389. drr = var(d + e, 1);
  390. REQUIRE( rr.n_rows == 1 );
  391. REQUIRE( rr.n_cols == 75 );
  392. for (uword j = 0; j < 75; ++j)
  393. {
  394. REQUIRE( drr[j] == Approx((double) rr[j]) );
  395. }
  396. cr = var(x + y, 0, 1);
  397. dcr = var(d + e, 0, 1);
  398. REQUIRE( cr.n_rows == 50 );
  399. REQUIRE( cr.n_cols == 1 );
  400. for (uword j = 0; j < 50; ++j)
  401. {
  402. REQUIRE( dcr[j] == Approx((double) cr[j]) );
  403. }
  404. cr = var(x + y, 1, 1);
  405. dcr = var(d + e, 1, 1);
  406. REQUIRE( cr.n_rows == 50 );
  407. REQUIRE( cr.n_cols == 1 );
  408. for (uword j = 0; j < 50; ++j)
  409. {
  410. REQUIRE( dcr[j] == Approx((double) cr[j]) );
  411. }
  412. }
  413. }
  414. TEST_CASE("fn_var_sparse_alias_test")
  415. {
  416. sp_mat s;
  417. s.sprandu(70, 70, 0.3);
  418. mat d(s);
  419. s = var(s);
  420. d = var(d);
  421. REQUIRE( d.n_rows == s.n_rows );
  422. REQUIRE( d.n_cols == s.n_cols );
  423. for (uword i = 0; i < d.n_elem; ++i)
  424. {
  425. REQUIRE(d[i] == Approx((double) s[i]) );
  426. }
  427. s.sprandu(70, 70, 0.3);
  428. d = s;
  429. s = var(s, 1);
  430. d = var(d, 1);
  431. for (uword i = 0; i < d.n_elem; ++i)
  432. {
  433. REQUIRE( d[i] == Approx((double) s[i]) );
  434. }
  435. }