fn_trace.cpp 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. // Copyright 2015 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2015 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. #include <armadillo>
  16. #include "catch.hpp"
  17. using namespace arma;
  18. TEST_CASE("fn_trace_1")
  19. {
  20. mat A =
  21. "\
  22. 0.061198 0.201990 0.019678 -0.493936 -0.126745 0.051408;\
  23. 0.437242 0.058956 -0.149362 -0.045465 0.296153 0.035437;\
  24. -0.492474 -0.031309 0.314156 0.419733 0.068317 -0.454499;\
  25. 0.336352 0.411541 0.458476 -0.393139 -0.135040 0.373833;\
  26. 0.239585 -0.428913 -0.406953 -0.291020 -0.353768 0.258704;\
  27. ";
  28. vec diagonal = { 0.061198, 0.058956, 0.314156, -0.393139, -0.353768 };
  29. REQUIRE( accu( trace(A) - accu(diagonal) ) == Approx(0.0) );
  30. REQUIRE( accu( trace(2*A) - accu(2*diagonal) ) == Approx(0.0) );
  31. REQUIRE( accu( trace(A+A) - accu(diagonal+diagonal) ) == Approx(0.0) );
  32. // REQUIRE_THROWS( );
  33. }
  34. TEST_CASE("fn_trace_spmat")
  35. {
  36. SpMat<double> a(6, 6);
  37. a(0, 0) = 3.0;
  38. a(2, 1) = 4.4;
  39. a(4, 1) = 1.2;
  40. a(0, 2) = 3.1;
  41. a(1, 2) = 3.2;
  42. a(2, 2) = 3.3;
  43. a(3, 3) = 4.0;
  44. a(5, 3) = 6.0;
  45. a(5, 4) = 5.9;
  46. a(5, 5) = 1.2;
  47. REQUIRE( trace(a) == Approx(11.5) );
  48. REQUIRE( trace(a.submat(2, 2, 4, 4)) == Approx(7.3) );
  49. }
  50. TEST_CASE("fn_trace_spmat_mul")
  51. {
  52. // Test trace(SpMat * SpMat) and ensure the result is the same as if we
  53. // pre-multiplied the matrices.
  54. sp_mat a;
  55. a.sprandu(20, 20, 0.1);
  56. sp_mat b;
  57. b.sprandu(20, 20, 0.1);
  58. sp_mat c = a * b;
  59. const double trc = trace(c);
  60. const double trab = trace(a * b);
  61. REQUIRE( trc == Approx(trab) );
  62. }
  63. TEST_CASE("fn_trace_spmat_t_mul")
  64. {
  65. // Test trace(SpMat.t() * SpMat) and ensure the result is the same as if we
  66. // pre-multiplied the matrices.
  67. sp_mat a;
  68. a.sprandu(20, 20, 0.1);
  69. sp_mat b;
  70. b.sprandu(20, 20, 0.1);
  71. sp_mat c = a.t() * b;
  72. const double trc = trace(c);
  73. const double trab = trace(a.t() * b);
  74. REQUIRE( trc == Approx(trab) );
  75. }