spcol.cpp 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. // Copyright 2011-2017 Ryan Curtin (http://www.ratml.org/)
  2. // Copyright 2017 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. #include <armadillo>
  16. #include "catch.hpp"
  17. using namespace arma;
  18. TEST_CASE("spcol_insert_test")
  19. {
  20. SpCol<double> sp;
  21. sp.set_size(10, 1);
  22. // Ensure everything is empty.
  23. for (size_t i = 0; i < 10; i++)
  24. REQUIRE( sp(i) == 0.0 );
  25. // Add an element.
  26. sp(5, 0) = 43.234;
  27. REQUIRE( sp.n_nonzero == 1 );
  28. REQUIRE( (double) sp(5, 0) == Approx(43.234) );
  29. // Remove the element.
  30. sp(5, 0) = 0.0;
  31. REQUIRE( sp.n_nonzero == 0 );
  32. }
  33. TEST_CASE("col_iterator_test")
  34. {
  35. SpCol<double> x(5, 1);
  36. x(3) = 3.1;
  37. x(0) = 4.2;
  38. x(1) = 3.3;
  39. x(1) = 5.5; // overwrite
  40. x(2) = 4.5;
  41. x(4) = 6.4;
  42. SpCol<double>::iterator it = x.begin();
  43. REQUIRE( (double) *it == Approx(4.2) );
  44. REQUIRE( it.row() == 0 );
  45. REQUIRE( it.col() == 0 );
  46. ++it;
  47. REQUIRE( (double) *it == Approx(5.5) );
  48. REQUIRE( it.row() == 1);
  49. REQUIRE( it.col() == 0);
  50. ++it;
  51. REQUIRE( (double) *it == Approx(4.5) );
  52. REQUIRE( it.row() == 2 );
  53. REQUIRE( it.col() == 0 );
  54. ++it;
  55. REQUIRE( (double) *it == Approx(3.1) );
  56. REQUIRE( it.row() == 3 );
  57. REQUIRE( it.col() == 0 );
  58. ++it;
  59. REQUIRE( (double) *it == Approx(6.4) );
  60. REQUIRE( it.row() == 4 );
  61. REQUIRE( it.col() == 0 );
  62. ++it;
  63. REQUIRE( it == x.end() );
  64. // Now let's go backwards.
  65. --it; // Get it off the end.
  66. REQUIRE( (double) *it == Approx(6.4) );
  67. REQUIRE( it.row() == 4 );
  68. REQUIRE( it.col() == 0 );
  69. --it;
  70. REQUIRE( (double) *it == Approx(3.1) );
  71. REQUIRE( it.row() == 3);
  72. REQUIRE( it.col() == 0);
  73. --it;
  74. REQUIRE( (double) *it == Approx(4.5) );
  75. REQUIRE( it.row() == 2);
  76. REQUIRE( it.col() == 0);
  77. --it;
  78. REQUIRE( (double) *it == Approx(5.5) );
  79. REQUIRE( it.row() == 1 );
  80. REQUIRE( it.col() == 0 );
  81. --it;
  82. REQUIRE( (double) *it == Approx(4.2) );
  83. REQUIRE( it.row() == 0 );
  84. REQUIRE( it.col() == 0 );
  85. REQUIRE( it == x.begin() );
  86. // Try removing an element we iterated to.
  87. ++it;
  88. ++it;
  89. *it = 0;
  90. REQUIRE( x.n_nonzero == 4 );
  91. }
  92. TEST_CASE("basic_sp_col_operator_test")
  93. {
  94. // +=, -=, *=, /=, %=
  95. SpCol<double> a(6, 1);
  96. a(0) = 3.4;
  97. a(1) = 2.0;
  98. SpCol<double> b(6, 1);
  99. b(0) = 3.4;
  100. b(3) = 0.4;
  101. double addResult[6] = {6.8, 2.0, 0.0, 0.4, 0.0, 0.0};
  102. double subResult[6] = {0.0, 2.0, 0.0, -0.4, 0.0, 0.0};
  103. SpCol<double> out = a;
  104. out += b;
  105. REQUIRE( out.n_nonzero == 3 );
  106. for (u32 r = 0; r < 6; r++)
  107. {
  108. REQUIRE( (double) out(r) == Approx(addResult[r]) );
  109. }
  110. out = a;
  111. out -= b;
  112. REQUIRE( out.n_nonzero == 2 );
  113. for (u32 r = 0; r < 6; r++)
  114. {
  115. REQUIRE( (double) out(r) == Approx(subResult[r]) );
  116. }
  117. }
  118. /*
  119. BOOST_AUTO_TEST_CASE(SparseSparseColMultiplicationTest) {
  120. SpCol<double> spaa(4, 1);
  121. SpMat<double> spbb(1, 4);
  122. spaa(0, 0) = 321.2;
  123. spaa(1, 0) = .123;
  124. spaa(2, 0) = 231.4;
  125. spaa(3, 0) = .03214;
  126. spbb(0, 0) = 32.23;
  127. spbb(0, 1) = 5.1;
  128. spbb(0, 2) = 4.4;
  129. spbb(0, 3) = .88;
  130. SpMat<double> precision = spaa;
  131. precision *= spbb; //Wolfram alpha insisted on rounding..
  132. spaa *= spbb;
  133. for (size_t i = 0; i < 4; i++)
  134. for (size_t j = 0; j < 4; j++)
  135. BOOST_REQUIRE_CLOSE((double) spaa(i, j), (double) precision(i, j), 1e-5);
  136. }
  137. */
  138. TEST_CASE("spcol_shed_row_test")
  139. {
  140. // On an SpCol
  141. SpCol<int> e(10);
  142. e(1) = 5;
  143. e(4) = 56;
  144. e(5) = 6;
  145. e(7) = 4;
  146. e(8) = 2;
  147. e(9) = -1;
  148. e.shed_rows(4, 7);
  149. REQUIRE( e.n_cols == 1 );
  150. REQUIRE( e.n_rows == 6 );
  151. REQUIRE( e.n_nonzero == 3 );
  152. REQUIRE( (int) e[0] == 0 );
  153. REQUIRE( (int) e[1] == 5 );
  154. REQUIRE( (int) e[2] == 0 );
  155. REQUIRE( (int) e[3] == 0 );
  156. REQUIRE( (int) e[4] == 2 );
  157. REQUIRE( (int) e[5] == -1 );
  158. }
  159. TEST_CASE("spcol_col_constructor")
  160. {
  161. SpMat<double> m(100, 100);
  162. m.sprandu(100, 100, 0.3);
  163. SpCol<double> c = m.col(0);
  164. vec v(c);
  165. for (uword i = 0; i < 100; ++i)
  166. {
  167. REQUIRE( v(i) == (double) c(i) );
  168. }
  169. }