op_roots_meat.hpp 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup op_roots
  16. //! @{
  17. template<typename T1>
  18. inline
  19. void
  20. op_roots::apply(Mat< std::complex<typename T1::pod_type> >& out, const mtOp<std::complex<typename T1::pod_type>, T1, op_roots>& expr)
  21. {
  22. arma_extra_debug_sigprint();
  23. const bool status = op_roots::apply_direct(out, expr.m);
  24. if(status == false) { arma_stop_runtime_error("roots(): eigen decomposition failed"); }
  25. }
  26. template<typename T1>
  27. inline
  28. bool
  29. op_roots::apply_direct(Mat< std::complex<typename T1::pod_type> >& out, const Base<typename T1::elem_type, T1>& X)
  30. {
  31. arma_extra_debug_sigprint();
  32. typedef std::complex<typename T1::pod_type> out_eT;
  33. const quasi_unwrap<T1> U(X.get_ref());
  34. bool status = false;
  35. if(U.is_alias(out))
  36. {
  37. Mat<out_eT> tmp;
  38. status = op_roots::apply_noalias(tmp, U.M);
  39. out.steal_mem(tmp);
  40. }
  41. else
  42. {
  43. status = op_roots::apply_noalias(out, U.M);
  44. }
  45. if(status == false) { out.soft_reset(); }
  46. return status;
  47. }
  48. template<typename eT>
  49. inline
  50. bool
  51. op_roots::apply_noalias(Mat< std::complex<typename get_pod_type<eT>::result> >& out, const Mat<eT>& X)
  52. {
  53. arma_extra_debug_sigprint();
  54. typedef typename get_pod_type<eT>::result T;
  55. typedef std::complex<typename get_pod_type<eT>::result> out_eT;
  56. arma_debug_check( (X.is_vec() == false), "roots(): given object must be a vector" );
  57. if(X.is_finite() == false) { return false; }
  58. // treat X as a column vector
  59. const Col<eT> Y( const_cast<eT*>(X.memptr()), X.n_elem, false, false);
  60. const T Y_max = (Y.is_empty() == false) ? T(max(abs(Y))) : T(0);
  61. if(Y_max == T(0)) { out.set_size(1,0); return true; }
  62. const uvec indices = find( Y / Y_max );
  63. const uword n_tail_zeros = (indices.n_elem > 0) ? uword( (Y.n_elem-1) - indices[indices.n_elem-1] ) : uword(0);
  64. const Col<eT> Z = Y.subvec( indices[0], indices[indices.n_elem-1] );
  65. if(Z.n_elem >= uword(2))
  66. {
  67. Mat<eT> tmp;
  68. if(Z.n_elem == uword(2))
  69. {
  70. tmp.set_size(1,1);
  71. tmp[0] = -Z[1] / Z[0];
  72. }
  73. else
  74. {
  75. tmp = diagmat(ones< Col<eT> >(Z.n_elem - 2), -1);
  76. tmp.row(0) = strans(-Z.subvec(1, Z.n_elem-1) / Z[0]);
  77. }
  78. Mat<out_eT> junk;
  79. const bool status = auxlib::eig_gen(out, junk, false, tmp);
  80. if(status == false) { return false; }
  81. if(n_tail_zeros > 0)
  82. {
  83. out.resize(out.n_rows + n_tail_zeros, 1);
  84. }
  85. }
  86. else
  87. {
  88. out.zeros(n_tail_zeros,1);
  89. }
  90. return true;
  91. }
  92. //! @}