fn_sum.cpp 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. // Copyright 2015 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2015 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. #include <armadillo>
  16. #include "catch.hpp"
  17. using namespace arma;
  18. TEST_CASE("fn_sum_1")
  19. {
  20. vec a = linspace<vec>(1,5,5);
  21. vec b = linspace<vec>(1,5,6);
  22. REQUIRE(sum(a) == Approx(15.0));
  23. REQUIRE(sum(b) == Approx(18.0));
  24. }
  25. TEST_CASE("sum2")
  26. {
  27. mat A =
  28. {
  29. { -0.78838, 0.69298, 0.41084, 0.90142 },
  30. { 0.49345, -0.12020, 0.78987, 0.53124 },
  31. { 0.73573, 0.52104, -0.22263, 0.40163 }
  32. };
  33. rowvec colsums = { 0.44080, 1.09382, 0.97808, 1.83429 };
  34. colvec rowsums =
  35. {
  36. 1.21686,
  37. 1.69436,
  38. 1.43577
  39. };
  40. REQUIRE( accu(abs(colsums - sum(A ))) == Approx(0.0) );
  41. REQUIRE( accu(abs(colsums - sum(A,0))) == Approx(0.0) );
  42. REQUIRE( accu(abs(rowsums - sum(A,1))) == Approx(0.0) );
  43. }
  44. TEST_CASE("sum3")
  45. {
  46. mat AA =
  47. {
  48. { -0.78838, 0.69298, 0.41084, 0.90142 },
  49. { 0.49345, -0.12020, 0.78987, 0.53124 },
  50. { 0.73573, 0.52104, -0.22263, 0.40163 }
  51. };
  52. cx_mat A = cx_mat(AA, 0.5*AA);
  53. rowvec re_colsums = { 0.44080, 1.09382, 0.97808, 1.83429 };
  54. cx_rowvec cx_colsums = cx_rowvec(re_colsums, 0.5*re_colsums);
  55. colvec re_rowsums =
  56. {
  57. 1.21686,
  58. 1.69436,
  59. 1.43577
  60. };
  61. cx_colvec cx_rowsums = cx_colvec(re_rowsums, 0.5*re_rowsums);
  62. REQUIRE( accu(abs(cx_colsums - sum(A ))) == Approx(0.0) );
  63. REQUIRE( accu(abs(cx_colsums - sum(A,0))) == Approx(0.0) );
  64. REQUIRE( accu(abs(cx_rowsums - sum(A,1))) == Approx(0.0) );
  65. }
  66. TEST_CASE("sum4")
  67. {
  68. mat X(100,101, fill::randu);
  69. REQUIRE( (sum(sum(X))/X.n_elem) == Approx(0.5).epsilon(0.02) );
  70. REQUIRE( (sum(sum(X(span::all,span::all)))/X.n_elem) == Approx(0.5).epsilon(0.02) );
  71. }
  72. TEST_CASE("sum_spmat")
  73. {
  74. SpCol<double> a(5);
  75. a(0) = 3.0;
  76. a(2) = 1.5;
  77. a(3) = 1.0;
  78. double res = sum(a);
  79. REQUIRE( res == Approx(5.5) );
  80. SpRow<double> b(5);
  81. b(1) = 1.3;
  82. b(2) = 4.4;
  83. b(4) = 1.0;
  84. res = sum(b);
  85. REQUIRE( res == Approx(6.7) );
  86. SpMat<double> c(8, 8);
  87. c(0, 0) = 3.0;
  88. c(1, 0) = 2.5;
  89. c(6, 0) = 2.1;
  90. c(4, 1) = 3.2;
  91. c(5, 1) = 1.1;
  92. c(1, 2) = 1.3;
  93. c(2, 3) = 4.1;
  94. c(5, 5) = 2.3;
  95. c(6, 5) = 3.1;
  96. c(7, 5) = 1.2;
  97. c(7, 7) = 3.4;
  98. SpMat<double> result = sum(c, 0);
  99. REQUIRE( result.n_rows == 1 );
  100. REQUIRE( result.n_cols == 8 );
  101. REQUIRE( result.n_nonzero == 6 );
  102. REQUIRE( (double) result(0, 0) == Approx(7.6) );
  103. REQUIRE( (double) result(0, 1) == Approx(4.3) );
  104. REQUIRE( (double) result(0, 2) == Approx(1.3) );
  105. REQUIRE( (double) result(0, 3) == Approx(4.1) );
  106. REQUIRE( (double) result(0, 4) == Approx(0.0) );
  107. REQUIRE( (double) result(0, 5) == Approx(6.6) );
  108. REQUIRE( (double) result(0, 6) == Approx(0.0) );
  109. REQUIRE( (double) result(0, 7) == Approx(3.4) );
  110. result = sum(c, 1);
  111. REQUIRE( result.n_rows == 8 );
  112. REQUIRE( result.n_cols == 1 );
  113. REQUIRE( result.n_nonzero == 7 );
  114. REQUIRE( (double) result(0, 0) == Approx(3.0) );
  115. REQUIRE( (double) result(1, 0) == Approx(3.8) );
  116. REQUIRE( (double) result(2, 0) == Approx(4.1) );
  117. REQUIRE( (double) result(3, 0) == Approx(0.0) );
  118. REQUIRE( (double) result(4, 0) == Approx(3.2) );
  119. REQUIRE( (double) result(5, 0) == Approx(3.4) );
  120. REQUIRE( (double) result(6, 0) == Approx(5.2) );
  121. REQUIRE( (double) result(7, 0) == Approx(4.6) );
  122. }