fn_min.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452
  1. // Copyright 2011-2017 Ryan Curtin (http://www.ratml.org/)
  2. // Copyright 2017 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. #include <armadillo>
  16. #include "catch.hpp"
  17. using namespace arma;
  18. TEST_CASE("fn_min_weird_operation")
  19. {
  20. mat a(10, 10);
  21. mat b(25, 10);
  22. a.randn();
  23. b.randn();
  24. mat output = a * b.t();
  25. uword real_min;
  26. uword operation_min;
  27. const double mval = output.min(real_min);
  28. const double other_mval = (a * b.t()).min(operation_min);
  29. REQUIRE( real_min == operation_min );
  30. REQUIRE( mval == Approx(other_mval) );
  31. }
  32. TEST_CASE("fn_min_weird_sparse_operation")
  33. {
  34. sp_mat a(10, 10);
  35. sp_mat b(25, 10);
  36. a.sprandn(10, 10, 0.3);
  37. b.sprandn(25, 10, 0.3);
  38. sp_mat output = a * b.t();
  39. uword real_min;
  40. uword operation_min;
  41. const double mval = output.min(real_min);
  42. const double other_mval = (a * b.t()).min(operation_min);
  43. REQUIRE( real_min == operation_min );
  44. REQUIRE( mval == Approx(other_mval) );
  45. }
  46. TEST_CASE("fn_min_sp_subview_test")
  47. {
  48. // We will assume subview.at() works and returns points within the bounds of
  49. // the matrix, so we just have to ensure the results are the same as
  50. // Mat.min()...
  51. for (size_t r = 50; r < 150; ++r)
  52. {
  53. sp_mat x;
  54. x.sprandn(r, r, 0.3);
  55. uword x_min;
  56. uword x_subview_min1;
  57. uword x_subview_min2;
  58. uword x_subview_min3;
  59. const double mval = x.min(x_min);
  60. const double mval1 = x.submat(0, 0, r - 1, r - 1).min(x_subview_min1);
  61. const double mval2 = x.cols(0, r - 1).min(x_subview_min2);
  62. const double mval3 = x.rows(0, r - 1).min(x_subview_min3);
  63. if (mval != 0.0)
  64. {
  65. REQUIRE( x_min == x_subview_min1 );
  66. REQUIRE( x_min == x_subview_min2 );
  67. REQUIRE( x_min == x_subview_min3 );
  68. REQUIRE( mval == Approx(mval1) );
  69. REQUIRE( mval == Approx(mval2) );
  70. REQUIRE( mval == Approx(mval3) );
  71. }
  72. }
  73. }
  74. TEST_CASE("fn_min_spsubview_col_test")
  75. {
  76. for (size_t r = 10; r < 50; ++r)
  77. {
  78. sp_vec x;
  79. x.sprandn(r, 1, 0.3);
  80. uword x_min;
  81. uword x_subview_min1;
  82. uword x_subview_min2;
  83. const double mval = x.min(x_min);
  84. const double mval1 = x.submat(0, 0, r - 1, 0).min(x_subview_min1);
  85. const double mval2 = x.rows(0, r - 1).min(x_subview_min2);
  86. if (mval != 0.0)
  87. {
  88. REQUIRE( x_min == x_subview_min1 );
  89. REQUIRE( x_min == x_subview_min2 );
  90. REQUIRE( mval == Approx(mval1) );
  91. REQUIRE( mval == Approx(mval2) );
  92. }
  93. }
  94. }
  95. TEST_CASE("fn_min_spsubview_row_min_test")
  96. {
  97. for (size_t r = 10; r < 50; ++r)
  98. {
  99. sp_rowvec x;
  100. x.sprandn(1, r, 0.3);
  101. uword x_min;
  102. uword x_subview_min1;
  103. uword x_subview_min2;
  104. const double mval = x.min(x_min);
  105. const double mval1 = x.submat(0, 0, 0, r - 1).min(x_subview_min1);
  106. const double mval2 = x.cols(0, r - 1).min(x_subview_min2);
  107. if (mval != 0.0)
  108. {
  109. REQUIRE( x_min == x_subview_min1 );
  110. REQUIRE( x_min == x_subview_min2 );
  111. REQUIRE( mval == Approx(mval1) );
  112. REQUIRE( mval == Approx(mval2) );
  113. }
  114. }
  115. }
  116. TEST_CASE("fn_min_spincompletesubview_min_test")
  117. {
  118. for (size_t r = 50; r < 150; ++r)
  119. {
  120. sp_mat x;
  121. x.sprandn(r, r, 0.3);
  122. uword x_min;
  123. uword x_subview_min1;
  124. uword x_subview_min2;
  125. uword x_subview_min3;
  126. const double mval = x.min(x_min);
  127. const double mval1 = x.submat(1, 1, r - 2, r - 2).min(x_subview_min1);
  128. const double mval2 = x.cols(1, r - 2).min(x_subview_min2);
  129. const double mval3 = x.rows(1, r - 2).min(x_subview_min3);
  130. uword row, col;
  131. x.min(row, col);
  132. if (row != 0 && row != r - 1 && col != 0 && col != r - 1 && mval != 0.0)
  133. {
  134. uword srow, scol;
  135. srow = x_subview_min1 % (r - 2);
  136. scol = x_subview_min1 / (r - 2);
  137. REQUIRE( x_min == (srow + 1) + r * (scol + 1) );
  138. REQUIRE( x_min == x_subview_min2 + r );
  139. srow = x_subview_min3 % (r - 2);
  140. scol = x_subview_min3 / (r - 2);
  141. REQUIRE( x_min == (srow + 1) + r * scol );
  142. REQUIRE( mval == Approx(mval1) );
  143. REQUIRE( mval == Approx(mval2) );
  144. REQUIRE( mval == Approx(mval3) );
  145. }
  146. }
  147. }
  148. TEST_CASE("fn_min_spincompletesubview_col_min_test")
  149. {
  150. for (size_t r = 10; r < 50; ++r)
  151. {
  152. sp_vec x;
  153. x.sprandu(r, 1, 0.3);
  154. uword x_min;
  155. uword x_subview_min1;
  156. uword x_subview_min2;
  157. const double mval = x.min(x_min);
  158. const double mval1 = x.submat(1, 0, r - 2, 0).min(x_subview_min1);
  159. const double mval2 = x.rows(1, r - 2).min(x_subview_min2);
  160. if (x_min != 0 && x_min != r - 1 && mval != 0.0)
  161. {
  162. REQUIRE( x_min == x_subview_min1 + 1 );
  163. REQUIRE( x_min == x_subview_min2 + 1 );
  164. REQUIRE( mval == Approx(mval1) );
  165. REQUIRE( mval == Approx(mval2) );
  166. }
  167. }
  168. }
  169. TEST_CASE("fn_min_spincompletesubview_row_min_test")
  170. {
  171. for (size_t r = 10; r < 50; ++r)
  172. {
  173. sp_rowvec x;
  174. x.sprandn(1, r, 0.3);
  175. uword x_min;
  176. uword x_subview_min1;
  177. uword x_subview_min2;
  178. const double mval = x.min(x_min);
  179. const double mval1 = x.submat(0, 1, 0, r - 2).min(x_subview_min1);
  180. const double mval2 = x.cols(1, r - 2).min(x_subview_min2);
  181. if (mval != 0.0 && x_min != 0 && x_min != r - 1)
  182. {
  183. REQUIRE( x_min == x_subview_min1 + 1 );
  184. REQUIRE( x_min == x_subview_min2 + 1 );
  185. REQUIRE( mval == Approx(mval1) );
  186. REQUIRE( mval == Approx(mval2) );
  187. }
  188. }
  189. }
  190. TEST_CASE("fn_min_sp_cx_subview_min_test")
  191. {
  192. // We will assume subview.at() works and returns points within the bounds of
  193. // the matrix, so we just have to ensure the results are the same as
  194. // Mat.min()...
  195. for (size_t r = 50; r < 150; ++r)
  196. {
  197. sp_cx_mat x;
  198. x.sprandn(r, r, 0.3);
  199. uword x_min;
  200. uword x_subview_min1;
  201. uword x_subview_min2;
  202. uword x_subview_min3;
  203. const std::complex<double> mval = x.min(x_min);
  204. const std::complex<double> mval1 = x.submat(0, 0, r - 1, r - 1).min(x_subview_min1);
  205. const std::complex<double> mval2 = x.cols(0, r - 1).min(x_subview_min2);
  206. const std::complex<double> mval3 = x.rows(0, r - 1).min(x_subview_min3);
  207. if (mval != std::complex<double>(0.0))
  208. {
  209. REQUIRE( x_min == x_subview_min1 );
  210. REQUIRE( x_min == x_subview_min2 );
  211. REQUIRE( x_min == x_subview_min3 );
  212. REQUIRE( mval.real() == Approx(mval1.real()) );
  213. REQUIRE( mval.imag() == Approx(mval1.imag()) );
  214. REQUIRE( mval.real() == Approx(mval2.real()) );
  215. REQUIRE( mval.imag() == Approx(mval2.imag()) );
  216. REQUIRE( mval.real() == Approx(mval3.real()) );
  217. REQUIRE( mval.imag() == Approx(mval3.imag()) );
  218. }
  219. }
  220. }
  221. TEST_CASE("fn_min_sp_cx_subview_col_min_test")
  222. {
  223. for (size_t r = 10; r < 50; ++r)
  224. {
  225. sp_cx_vec x;
  226. x.sprandn(r, 1, 0.3);
  227. uword x_min;
  228. uword x_subview_min1;
  229. uword x_subview_min2;
  230. const std::complex<double> mval = x.min(x_min);
  231. const std::complex<double> mval1 = x.submat(0, 0, r - 1, 0).min(x_subview_min1);
  232. const std::complex<double> mval2 = x.rows(0, r - 1).min(x_subview_min2);
  233. if (mval != std::complex<double>(0.0))
  234. {
  235. REQUIRE( x_min == x_subview_min1 );
  236. REQUIRE( x_min == x_subview_min2 );
  237. REQUIRE( mval.real() == Approx(mval1.real()) );
  238. REQUIRE( mval.imag() == Approx(mval1.imag()) );
  239. REQUIRE( mval.real() == Approx(mval2.real()) );
  240. REQUIRE( mval.imag() == Approx(mval2.imag()) );
  241. }
  242. }
  243. }
  244. TEST_CASE("fn_min_sp_cx_subview_row_min_test")
  245. {
  246. for (size_t r = 10; r < 50; ++r)
  247. {
  248. sp_cx_rowvec x;
  249. x.sprandn(1, r, 0.3);
  250. uword x_min;
  251. uword x_subview_min1;
  252. uword x_subview_min2;
  253. const std::complex<double> mval = x.min(x_min);
  254. const std::complex<double> mval1 = x.submat(0, 0, 0, r - 1).min(x_subview_min1);
  255. const std::complex<double> mval2 = x.cols(0, r - 1).min(x_subview_min2);
  256. if (mval != std::complex<double>(0.0))
  257. {
  258. REQUIRE( x_min == x_subview_min1 );
  259. REQUIRE( x_min == x_subview_min2 );
  260. REQUIRE( mval.real() == Approx(mval1.real()) );
  261. REQUIRE( mval.imag() == Approx(mval1.imag()) );
  262. REQUIRE( mval.real() == Approx(mval2.real()) );
  263. REQUIRE( mval.imag() == Approx(mval2.imag()) );
  264. }
  265. }
  266. }
  267. TEST_CASE("fn_min_sp_cx_incomplete_subview_min_test")
  268. {
  269. for (size_t r = 50; r < 150; ++r)
  270. {
  271. sp_cx_mat x;
  272. x.sprandn(r, r, 0.3);
  273. uword x_min;
  274. uword x_subview_min1;
  275. uword x_subview_min2;
  276. uword x_subview_min3;
  277. const std::complex<double> mval = x.min(x_min);
  278. const std::complex<double> mval1 = x.submat(1, 1, r - 2, r - 2).min(x_subview_min1);
  279. const std::complex<double> mval2 = x.cols(1, r - 2).min(x_subview_min2);
  280. const std::complex<double> mval3 = x.rows(1, r - 2).min(x_subview_min3);
  281. uword row, col;
  282. x.min(row, col);
  283. if (row != 0 && row != r - 1 && col != 0 && col != r - 1 && mval != std::complex<double>(0.0))
  284. {
  285. uword srow, scol;
  286. srow = x_subview_min1 % (r - 2);
  287. scol = x_subview_min1 / (r - 2);
  288. REQUIRE( x_min == (srow + 1) + r * (scol + 1) );
  289. REQUIRE( x_min == x_subview_min2 + r );
  290. srow = x_subview_min3 % (r - 2);
  291. scol = x_subview_min3 / (r - 2);
  292. REQUIRE( x_min == (srow + 1) + r * scol );
  293. REQUIRE( mval.real() == Approx(mval1.real()) );
  294. REQUIRE( mval.imag() == Approx(mval1.imag()) );
  295. REQUIRE( mval.real() == Approx(mval2.real()) );
  296. REQUIRE( mval.imag() == Approx(mval2.imag()) );
  297. REQUIRE( mval.real() == Approx(mval3.real()) );
  298. REQUIRE( mval.imag() == Approx(mval3.imag()) );
  299. }
  300. }
  301. }
  302. TEST_CASE("fn_min_sp_cx_incomplete_subview_col_min_test")
  303. {
  304. for (size_t r = 10; r < 50; ++r)
  305. {
  306. arma::sp_cx_vec x;
  307. x.sprandn(r, 1, 0.3);
  308. uword x_min;
  309. uword x_subview_min1;
  310. uword x_subview_min2;
  311. const std::complex<double> mval = x.min(x_min);
  312. const std::complex<double> mval1 = x.submat(1, 0, r - 2, 0).min(x_subview_min1);
  313. const std::complex<double> mval2 = x.rows(1, r - 2).min(x_subview_min2);
  314. if (x_min != 0 && x_min != r - 1 && mval != std::complex<double>(0.0))
  315. {
  316. REQUIRE( x_min == x_subview_min1 + 1 );
  317. REQUIRE( x_min == x_subview_min2 + 1 );
  318. REQUIRE( mval.real() == Approx(mval1.real()) );
  319. REQUIRE( mval.imag() == Approx(mval1.imag()) );
  320. REQUIRE( mval.real() == Approx(mval2.real()) );
  321. REQUIRE( mval.imag() == Approx(mval2.imag()) );
  322. }
  323. }
  324. }
  325. TEST_CASE("fn_min_sp_cx_incomplete_subview_row_min_test")
  326. {
  327. for (size_t r = 10; r < 50; ++r)
  328. {
  329. sp_cx_rowvec x;
  330. x.sprandn(1, r, 0.3);
  331. uword x_min;
  332. uword x_subview_min1;
  333. uword x_subview_min2;
  334. const std::complex<double> mval = x.min(x_min);
  335. const std::complex<double> mval1 = x.submat(0, 1, 0, r - 2).min(x_subview_min1);
  336. const std::complex<double> mval2 = x.cols(1, r - 2).min(x_subview_min2);
  337. if (x_min != 0 && x_min != r - 1 && mval != std::complex<double>(0.0))
  338. {
  339. REQUIRE( x_min == x_subview_min1 + 1 );
  340. REQUIRE( x_min == x_subview_min2 + 1 );
  341. REQUIRE( mval.real() == Approx(mval1.real()) );
  342. REQUIRE( mval.imag() == Approx(mval1.imag()) );
  343. REQUIRE( mval.real() == Approx(mval2.real()) );
  344. REQUIRE( mval.imag() == Approx(mval2.imag()) );
  345. }
  346. }
  347. }