fn_diagmat.cpp 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. // Copyright 2015 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2015 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. #include <armadillo>
  16. #include "catch.hpp"
  17. using namespace arma;
  18. TEST_CASE("fn_diagmat_1")
  19. {
  20. mat A =
  21. {
  22. { -0.78838, 0.69298, 0.41084, 0.90142 },
  23. { 0.49345, -0.12020, 0.78987, 0.53124 },
  24. { 0.73573, 0.52104, -0.22263, 0.40163 }
  25. };
  26. mat Ap1 =
  27. {
  28. { -0.0 , 0.69298, 0.0 , 0.0 },
  29. { 0.0 , 0.0 , 0.78987, 0.0 },
  30. { 0.0 , 0.0 , 0.0 , 0.40163 }
  31. };
  32. mat Amain =
  33. {
  34. { -0.78838, 0.0 , 0.0 , 0.0 },
  35. { 0.0 , -0.12020, 0.0 , 0.0 },
  36. { 0.0 , 0.0 , -0.22263, 0.0 }
  37. };
  38. mat Am1 =
  39. {
  40. { 0.0 , 0.0 , 0.0 , 0.0 },
  41. { 0.49345, 0.0 , 0.0 , 0.0 },
  42. { 0.0 , 0.52104, 0.0 , 0.0 }
  43. };
  44. REQUIRE( accu(abs(diagmat(A ) - Amain)) == Approx(0.0 ) );
  45. REQUIRE( accu(abs(diagmat(A, 0) - Amain)) == Approx(0.0 ) );
  46. REQUIRE( accu(abs(diagmat(A, 1) - Ap1 )) == Approx(0.0 ) );
  47. REQUIRE( accu(abs(diagmat(A,-1) - Am1 )) == Approx(0.0 ) );
  48. }
  49. TEST_CASE("fn_diagmat_2")
  50. {
  51. mat A =
  52. {
  53. { -0.78838, 0.69298, 0.41084, 0.90142 },
  54. { 0.49345, -0.12020, 0.78987, 0.53124 },
  55. { 0.73573, 0.52104, -0.22263, 0.40163 }
  56. };
  57. vec dp1 = { 0.69298, 0.78987, 0.40163 };
  58. vec dmain = { -0.78838, -0.12020, -0.22263 };
  59. vec dm1 = { 0.49345, 0.52104 };
  60. mat Ap1 (size(A),fill::zeros); Ap1.diag( 1) = dp1;
  61. mat Amain(size(A),fill::zeros); Amain.diag( ) = dmain;
  62. mat Am1 (size(A),fill::zeros); Am1.diag(-1) = dm1;
  63. REQUIRE( accu(abs(diagmat(A ) - Amain)) == Approx(0.0) );
  64. REQUIRE( accu(abs(diagmat(A, 0) - Amain)) == Approx(0.0) );
  65. REQUIRE( accu(abs(diagmat(A, 1) - Ap1)) == Approx(0.0) );
  66. REQUIRE( accu(abs(diagmat(A,-1) - Am1)) == Approx(0.0) );
  67. }
  68. TEST_CASE("fn_diagmat_3")
  69. {
  70. mat A =
  71. {
  72. { -0.78838, 0.69298, 0.41084, 0.90142 },
  73. { 0.49345, -0.12020, 0.78987, 0.53124 },
  74. { 0.73573, 0.52104, -0.22263, 0.40163 }
  75. };
  76. mat B =
  77. "\
  78. 0.171180 0.106848 0.490557 -0.079866;\
  79. 0.073839 -0.428277 -0.049842 0.398193;\
  80. -0.030523 0.366160 0.260348 -0.412238;\
  81. ";
  82. mat Asub = A(span::all,span(0,2));
  83. mat At = A.t();
  84. mat Bsub = B(span::all,span(0,2));
  85. mat Bt = B.t();
  86. mat Asubdiagmat_times_Bsubdiagmat =
  87. "\
  88. -0.13495488840 0.00000000000 0.00000000000;\
  89. 0.00000000000 0.05147889540 0.00000000000;\
  90. 0.00000000000 0.00000000000 -0.05796127524;\
  91. ";
  92. mat Bsub_times_Adiagmat =
  93. "\
  94. -0.13495488840 -0.01284312960 -0.10921270491 0.00000000000;\
  95. -0.05821319082 0.05147889540 0.01109632446 0.00000000000;\
  96. 0.02406372274 -0.04401243200 -0.05796127524 0.00000000000;\
  97. ";
  98. mat Adiagmat_times_Bt =
  99. "\
  100. -0.134955 -0.058213 0.024064;\
  101. -0.012843 0.051479 -0.044012;\
  102. -0.109213 0.011096 -0.057961;\
  103. ";
  104. REQUIRE( accu(abs((diagmat(Asub) * diagmat(Bsub)) - Asubdiagmat_times_Bsubdiagmat)) == Approx(0.0) );
  105. REQUIRE( accu(abs((Bsub * diagmat(A)) - Bsub_times_Adiagmat)) == Approx(0.0) );
  106. REQUIRE( accu(abs((B(span::all, span(0,2)) * diagmat(A)) - Bsub_times_Adiagmat)) == Approx(0.0) );
  107. REQUIRE( accu(abs((diagmat(A) * Bt ) - Adiagmat_times_Bt )) == Approx(0.0) );
  108. REQUIRE( accu(abs((diagmat(A) * B.t() ) - Adiagmat_times_Bt )) == Approx(0.0) );
  109. // TODO: Asub and At
  110. }