gen_ones.cpp 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. // Copyright 2015 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2015 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. #include <armadillo>
  16. #include "catch.hpp"
  17. using namespace arma;
  18. TEST_CASE("gen_ones_1")
  19. {
  20. mat A(5,6,fill::ones);
  21. REQUIRE( accu(A) == Approx(double(5*6)) );
  22. REQUIRE( A.n_rows == 5 );
  23. REQUIRE( A.n_cols == 6 );
  24. mat B(5,6,fill::randu);
  25. B.ones();
  26. REQUIRE( accu(B) == Approx(double(5*6)) );
  27. REQUIRE( B.n_rows == 5 );
  28. REQUIRE( B.n_cols == 6 );
  29. mat C = ones<mat>(5,6);
  30. REQUIRE( accu(C) == Approx(double(5*6)) );
  31. REQUIRE( C.n_rows == 5 );
  32. REQUIRE( C.n_cols == 6 );
  33. mat D; D = ones<mat>(5,6);
  34. REQUIRE( accu(D) == Approx(double(5*6)) );
  35. REQUIRE( D.n_rows == 5 );
  36. REQUIRE( D.n_cols == 6 );
  37. mat E; E = 2*ones<mat>(5,6);
  38. REQUIRE( accu(E) == Approx(double(2*5*6)) );
  39. REQUIRE( E.n_rows == 5 );
  40. REQUIRE( E.n_cols == 6 );
  41. }
  42. TEST_CASE("gen_ones_2")
  43. {
  44. mat A(5,6,fill::zeros);
  45. A.col(1).ones();
  46. REQUIRE( accu(A.col(0)) == Approx(0.0) );
  47. REQUIRE( accu(A.col(1)) == Approx(double(A.n_rows)) );
  48. REQUIRE( accu(A.col(2)) == Approx(0.0) );
  49. mat B(5,6,fill::zeros);
  50. B.row(1).ones();
  51. REQUIRE( accu(B.row(0)) == Approx(0.0) );
  52. REQUIRE( accu(B.row(1)) == Approx(double(B.n_cols)) );
  53. REQUIRE( accu(B.row(2)) == Approx(0.0) );
  54. mat C(5,6,fill::zeros);
  55. C(span(1,3),span(1,4)).ones();
  56. REQUIRE( accu(C.head_cols(1)) == Approx(0.0) );
  57. REQUIRE( accu(C.head_rows(1)) == Approx(0.0) );
  58. REQUIRE( accu(C.tail_cols(1)) == Approx(0.0) );
  59. REQUIRE( accu(C.tail_rows(1)) == Approx(0.0) );
  60. REQUIRE( accu(C(span(1,3),span(1,4))) == Approx(double(3*4)) );
  61. mat D(5,6,fill::zeros);
  62. D.diag().ones();
  63. REQUIRE( accu(D.diag()) == Approx(double(5)) );
  64. }
  65. TEST_CASE("gen_ones_3")
  66. {
  67. mat A(5,6,fill::zeros);
  68. uvec indices = { 2, 4, 6 };
  69. A(indices).ones();
  70. REQUIRE( accu(A) == Approx(double(3)) );
  71. REQUIRE( A(0) == Approx(0.0) );
  72. REQUIRE( A(A.n_elem-1) == Approx(0.0) );
  73. REQUIRE( A(indices(0)) == Approx(1.0) );
  74. REQUIRE( A(indices(1)) == Approx(1.0) );
  75. REQUIRE( A(indices(2)) == Approx(1.0) );
  76. }