glue_mvnrnd_meat.hpp 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup glue_mvnrnd
  16. //! @{
  17. // implementation based on:
  18. // James E. Gentle.
  19. // Generation of Random Numbers.
  20. // Computational Statistics, pp. 305-331, 2009.
  21. // http://dx.doi.org/10.1007/978-0-387-98144-4_7
  22. template<typename T1, typename T2>
  23. inline
  24. void
  25. glue_mvnrnd_vec::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_mvnrnd_vec>& expr)
  26. {
  27. arma_extra_debug_sigprint();
  28. const bool status = glue_mvnrnd::apply_direct(out, expr.A, expr.B, uword(1));
  29. if(status == false)
  30. {
  31. arma_stop_runtime_error("mvnrnd(): given covariance matrix is not symmetric positive semi-definite");
  32. }
  33. }
  34. template<typename T1, typename T2>
  35. inline
  36. void
  37. glue_mvnrnd::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_mvnrnd>& expr)
  38. {
  39. arma_extra_debug_sigprint();
  40. const bool status = glue_mvnrnd::apply_direct(out, expr.A, expr.B, expr.aux_uword);
  41. if(status == false)
  42. {
  43. arma_stop_runtime_error("mvnrnd(): given covariance matrix is not symmetric positive semi-definite");
  44. }
  45. }
  46. template<typename T1, typename T2>
  47. inline
  48. bool
  49. glue_mvnrnd::apply_direct(Mat<typename T1::elem_type>& out, const Base<typename T1::elem_type,T1>& M, const Base<typename T1::elem_type,T2>& C, const uword N)
  50. {
  51. arma_extra_debug_sigprint();
  52. typedef typename T1::elem_type eT;
  53. const quasi_unwrap<T1> UM(M.get_ref());
  54. const quasi_unwrap<T2> UC(C.get_ref());
  55. arma_debug_check( (UM.M.is_colvec() == false) && (UM.M.is_empty() == false), "mvnrnd(): given mean must be a column vector" );
  56. arma_debug_check( (UC.M.is_square() == false), "mvnrnd(): given covariance matrix must be square sized" );
  57. arma_debug_check( (UM.M.n_rows != UC.M.n_rows), "mvnrnd(): number of rows in given mean vector and covariance matrix must match" );
  58. if( UM.M.is_empty() || UC.M.is_empty() )
  59. {
  60. out.set_size(0,N);
  61. return true;
  62. }
  63. // if(auxlib::rudimentary_sym_check(UC.M) == false)
  64. // {
  65. // arma_debug_warn("mvnrnd(): given matrix is not symmetric");
  66. // return false;
  67. // }
  68. if((arma_config::debug) && (auxlib::rudimentary_sym_check(UC.M) == false))
  69. {
  70. arma_debug_warn("mvnrnd(): given matrix is not symmetric");
  71. }
  72. bool status = false;
  73. if(UM.is_alias(out) || UC.is_alias(out))
  74. {
  75. Mat<eT> tmp;
  76. status = glue_mvnrnd::apply_noalias(tmp, UM.M, UC.M, N);
  77. out.steal_mem(tmp);
  78. }
  79. else
  80. {
  81. status = glue_mvnrnd::apply_noalias(out, UM.M, UC.M, N);
  82. }
  83. if(status == false) { out.soft_reset(); }
  84. return status;
  85. }
  86. template<typename eT>
  87. inline
  88. bool
  89. glue_mvnrnd::apply_noalias(Mat<eT>& out, const Mat<eT>& M, const Mat<eT>& C, const uword N)
  90. {
  91. arma_extra_debug_sigprint();
  92. Mat<eT> D;
  93. const bool chol_status = op_chol::apply_direct(D, C, 1); // '1' means "lower triangular"
  94. if(chol_status == false)
  95. {
  96. // C is not symmetric positive definite, so find approximate square root of C
  97. Col<eT> eigval; // NOTE: eT is constrained to be real (ie. float or double) in fn_mvnrnd.hpp
  98. Mat<eT> eigvec;
  99. const bool eig_status = eig_sym_helper(eigval, eigvec, C, 'd', "mvnrnd()");
  100. if(eig_status == false) { return false; }
  101. eT* eigval_mem = eigval.memptr();
  102. const uword eigval_n_elem = eigval.n_elem;
  103. // since we're doing an approximation, tolerate tiny negative eigenvalues
  104. const eT tol = eT(-100) * Datum<eT>::eps * norm(C, "fro");
  105. if(arma_isfinite(tol) == false) { return false; }
  106. for(uword i=0; i<eigval_n_elem; ++i)
  107. {
  108. const eT val = eigval_mem[i];
  109. if( (val < tol) || (arma_isfinite(val) == false) ) { return false; }
  110. }
  111. for(uword i=0; i<eigval_n_elem; ++i) { if(eigval_mem[i] < eT(0)) { eigval_mem[i] = eT(0); } }
  112. Mat<eT> DD = eigvec * diagmat(sqrt(eigval));
  113. D.steal_mem(DD);
  114. }
  115. out = D * randn< Mat<eT> >(M.n_rows, N);
  116. if(N == 1)
  117. {
  118. out += M;
  119. }
  120. else
  121. if(N > 1)
  122. {
  123. out.each_col() += M;
  124. }
  125. return true;
  126. }
  127. //! @}