op_inv_meat.hpp 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup op_inv
  16. //! @{
  17. template<typename T1>
  18. inline
  19. void
  20. op_inv::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_inv>& X)
  21. {
  22. arma_extra_debug_sigprint();
  23. typedef typename T1::elem_type eT;
  24. const strip_diagmat<T1> strip(X.m);
  25. bool status = false;
  26. if(strip.do_diagmat)
  27. {
  28. status = op_inv::apply_diagmat(out, strip.M);
  29. }
  30. else
  31. {
  32. const quasi_unwrap<T1> U(X.m);
  33. if(U.is_alias(out))
  34. {
  35. Mat<eT> tmp;
  36. status = op_inv::apply_noalias(tmp, U.M);
  37. out.steal_mem(tmp);
  38. }
  39. else
  40. {
  41. status = op_inv::apply_noalias(out, U.M);
  42. }
  43. }
  44. if(status == false)
  45. {
  46. out.soft_reset();
  47. arma_stop_runtime_error("inv(): matrix seems singular");
  48. }
  49. }
  50. template<typename eT>
  51. inline
  52. bool
  53. op_inv::apply_noalias(Mat<eT>& out, const Mat<eT>& A)
  54. {
  55. arma_extra_debug_sigprint();
  56. arma_debug_check( (A.n_rows != A.n_cols), "inv(): given matrix must be square sized" );
  57. bool status = false;
  58. if(A.n_rows <= 4)
  59. {
  60. status = auxlib::inv_tiny(out, A);
  61. }
  62. else
  63. if(A.is_diagmat())
  64. {
  65. return op_inv::apply_diagmat(out, A);
  66. }
  67. else
  68. {
  69. const bool is_triu = trimat_helper::is_triu(A);
  70. const bool is_tril = (is_triu) ? false : trimat_helper::is_tril(A);
  71. if(is_triu || is_tril)
  72. {
  73. const uword layout = (is_triu) ? uword(0) : uword(1);
  74. return auxlib::inv_tr(out, A, layout);
  75. }
  76. else
  77. {
  78. #if defined(ARMA_OPTIMISE_SYMPD)
  79. const bool try_sympd = sympd_helper::guess_sympd_anysize(A);
  80. #else
  81. const bool try_sympd = false;
  82. #endif
  83. if(try_sympd)
  84. {
  85. status = auxlib::inv_sympd(out, A);
  86. if(status == false) { arma_extra_debug_print("warning: sympd optimisation failed"); }
  87. }
  88. // auxlib::inv_sympd() may have failed because A isn't really sympd
  89. }
  90. }
  91. if(status == false)
  92. {
  93. status = auxlib::inv(out, A);
  94. }
  95. return status;
  96. }
  97. template<typename T1>
  98. inline
  99. bool
  100. op_inv::apply_diagmat(Mat<typename T1::elem_type>& out, const T1& X)
  101. {
  102. arma_extra_debug_sigprint();
  103. typedef typename T1::elem_type eT;
  104. const diagmat_proxy<T1> A(X);
  105. arma_debug_check( (A.n_rows != A.n_cols), "inv(): given matrix must be square sized" );
  106. const uword N = (std::min)(A.n_rows, A.n_cols);
  107. bool status = true;
  108. if(A.is_alias(out) == false)
  109. {
  110. out.zeros(N,N);
  111. for(uword i=0; i<N; ++i)
  112. {
  113. const eT val = A[i];
  114. out.at(i,i) = eT(1) / val;
  115. status = (val == eT(0)) ? false : status;
  116. }
  117. }
  118. else
  119. {
  120. Mat<eT> tmp(N, N, fill::zeros);
  121. for(uword i=0; i<N; ++i)
  122. {
  123. const eT val = A[i];
  124. tmp.at(i,i) = eT(1) / val;
  125. status = (val == eT(0)) ? false : status;
  126. }
  127. out.steal_mem(tmp);
  128. }
  129. return status;
  130. }
  131. template<typename T1>
  132. inline
  133. void
  134. op_inv_tr::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_inv_tr>& X)
  135. {
  136. arma_extra_debug_sigprint();
  137. const bool status = auxlib::inv_tr(out, X.m, X.aux_uword_a);
  138. if(status == false)
  139. {
  140. out.soft_reset();
  141. arma_stop_runtime_error("inv(): matrix seems singular");
  142. }
  143. }
  144. template<typename T1>
  145. inline
  146. void
  147. op_inv_sympd::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_inv_sympd>& X)
  148. {
  149. arma_extra_debug_sigprint();
  150. const bool status = auxlib::inv_sympd(out, X.m);
  151. if(status == false)
  152. {
  153. out.soft_reset();
  154. arma_stop_runtime_error("inv_sympd(): matrix is singular or not positive definite");
  155. }
  156. }
  157. //! @}