op_wishrnd_meat.hpp 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup op_wishrnd
  16. //! @{
  17. // implementation based on:
  18. // Yu-Cheng Ku and Peter Bloomfield.
  19. // Generating Random Wishart Matrices with Fractional Degrees of Freedom in OX.
  20. // Oxmetrics User Conference, 2010.
  21. template<typename T1>
  22. inline
  23. void
  24. op_wishrnd::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_wishrnd>& expr)
  25. {
  26. arma_extra_debug_sigprint();
  27. typedef typename T1::elem_type eT;
  28. const eT df = expr.aux;
  29. const uword mode = expr.aux_uword_a;
  30. const bool status = op_wishrnd::apply_direct(out, expr.m, df, mode);
  31. if(status == false)
  32. {
  33. arma_stop_runtime_error("wishrnd(): given matrix is not symmetric positive definite");
  34. }
  35. }
  36. template<typename T1>
  37. inline
  38. bool
  39. op_wishrnd::apply_direct(Mat<typename T1::elem_type>& out, const Base<typename T1::elem_type,T1>& X, const typename T1::elem_type df, const uword mode)
  40. {
  41. arma_extra_debug_sigprint();
  42. typedef typename T1::elem_type eT;
  43. const quasi_unwrap<T1> U(X.get_ref());
  44. bool status = false;
  45. if(U.is_alias(out))
  46. {
  47. Mat<eT> tmp;
  48. if(mode == 1) { status = op_wishrnd::apply_noalias_mode1(tmp, U.M, df); }
  49. if(mode == 2) { status = op_wishrnd::apply_noalias_mode2(tmp, U.M, df); }
  50. out.steal_mem(tmp);
  51. }
  52. else
  53. {
  54. if(mode == 1) { status = op_wishrnd::apply_noalias_mode1(out, U.M, df); }
  55. if(mode == 2) { status = op_wishrnd::apply_noalias_mode2(out, U.M, df); }
  56. }
  57. if(status == false) { out.soft_reset(); }
  58. return status;
  59. }
  60. template<typename eT>
  61. inline
  62. bool
  63. op_wishrnd::apply_noalias_mode1(Mat<eT>& out, const Mat<eT>& S, const eT df)
  64. {
  65. arma_extra_debug_sigprint();
  66. arma_debug_check( (S.is_square() == false), "wishrnd(): given matrix must be square sized" );
  67. if(S.is_empty()) { out.reset(); return true; }
  68. if(auxlib::rudimentary_sym_check(S) == false) { return false; }
  69. Mat<eT> D;
  70. const bool status = op_chol::apply_direct(D, S, 0);
  71. if(status == false) { return false; }
  72. return op_wishrnd::apply_noalias_mode2(out, D, df);
  73. }
  74. template<typename eT>
  75. inline
  76. bool
  77. op_wishrnd::apply_noalias_mode2(Mat<eT>& out, const Mat<eT>& D, const eT df)
  78. {
  79. arma_extra_debug_sigprint();
  80. #if defined(ARMA_USE_CXX11)
  81. {
  82. arma_debug_check( (df <= eT(0)), "df must be greater than zero" );
  83. arma_debug_check( (D.is_square() == false), "wishrnd(): given matrix must be square sized" );
  84. if(D.is_empty()) { out.reset(); return true; }
  85. const uword N = D.n_rows;
  86. if(df < eT(N))
  87. {
  88. arma_extra_debug_print("simple generator");
  89. const uword df_floor = uword(std::floor(df));
  90. const Mat<eT> tmp = (randn< Mat<eT> >(df_floor, N)) * D;
  91. out = tmp.t() * tmp;
  92. }
  93. else
  94. {
  95. arma_extra_debug_print("standard generator");
  96. op_chi2rnd_varying_df<eT> chi2rnd_generator;
  97. Mat<eT> A(N, N, fill::zeros);
  98. for(uword i=0; i<N; ++i)
  99. {
  100. A.at(i,i) = std::sqrt( chi2rnd_generator(df - eT(i)) );
  101. }
  102. for(uword i=1; i < N; ++i)
  103. {
  104. arma_rng::randn<eT>::fill( A.colptr(i), i );
  105. }
  106. const Mat<eT> tmp = A * D;
  107. A.reset();
  108. out = tmp.t() * tmp;
  109. }
  110. return true;
  111. }
  112. #else
  113. {
  114. arma_ignore(out);
  115. arma_ignore(D);
  116. arma_ignore(df);
  117. arma_stop_logic_error("wishrnd(): C++11 compiler required");
  118. return false;
  119. }
  120. #endif
  121. }
  122. //
  123. template<typename T1>
  124. inline
  125. void
  126. op_iwishrnd::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_iwishrnd>& expr)
  127. {
  128. arma_extra_debug_sigprint();
  129. typedef typename T1::elem_type eT;
  130. const eT df = expr.aux;
  131. const uword mode = expr.aux_uword_a;
  132. const bool status = op_iwishrnd::apply_direct(out, expr.m, df, mode);
  133. if(status == false)
  134. {
  135. arma_stop_runtime_error("iwishrnd(): given matrix is not symmetric positive definite and/or df is too low");
  136. }
  137. }
  138. template<typename T1>
  139. inline
  140. bool
  141. op_iwishrnd::apply_direct(Mat<typename T1::elem_type>& out, const Base<typename T1::elem_type,T1>& X, const typename T1::elem_type df, const uword mode)
  142. {
  143. arma_extra_debug_sigprint();
  144. typedef typename T1::elem_type eT;
  145. const quasi_unwrap<T1> U(X.get_ref());
  146. bool status = false;
  147. if(U.is_alias(out))
  148. {
  149. Mat<eT> tmp;
  150. if(mode == 1) { status = op_iwishrnd::apply_noalias_mode1(tmp, U.M, df); }
  151. if(mode == 2) { status = op_iwishrnd::apply_noalias_mode2(tmp, U.M, df); }
  152. out.steal_mem(tmp);
  153. }
  154. else
  155. {
  156. if(mode == 1) { status = op_iwishrnd::apply_noalias_mode1(out, U.M, df); }
  157. if(mode == 2) { status = op_iwishrnd::apply_noalias_mode2(out, U.M, df); }
  158. }
  159. if(status == false) { out.soft_reset(); }
  160. return status;
  161. }
  162. template<typename eT>
  163. inline
  164. bool
  165. op_iwishrnd::apply_noalias_mode1(Mat<eT>& out, const Mat<eT>& T, const eT df)
  166. {
  167. arma_extra_debug_sigprint();
  168. arma_debug_check( (T.is_square() == false), "iwishrnd(): given matrix must be square sized" );
  169. if(T.is_empty()) { out.reset(); return true; }
  170. if(auxlib::rudimentary_sym_check(T) == false) { return false; }
  171. Mat<eT> Tinv;
  172. Mat<eT> Dinv;
  173. const bool inv_status = auxlib::inv_sympd(Tinv, T);
  174. if(inv_status == false) { return false; }
  175. const bool chol_status = op_chol::apply_direct(Dinv, Tinv, 0);
  176. if(chol_status == false) { return false; }
  177. return op_iwishrnd::apply_noalias_mode2(out, Dinv, df);
  178. }
  179. template<typename eT>
  180. inline
  181. bool
  182. op_iwishrnd::apply_noalias_mode2(Mat<eT>& out, const Mat<eT>& Dinv, const eT df)
  183. {
  184. arma_extra_debug_sigprint();
  185. #if defined(ARMA_USE_CXX11)
  186. {
  187. arma_debug_check( (df <= eT(0)), "df must be greater than zero" );
  188. arma_debug_check( (Dinv.is_square() == false), "iwishrnd(): given matrix must be square sized" );
  189. if(Dinv.is_empty()) { out.reset(); return true; }
  190. Mat<eT> tmp;
  191. const bool wishrnd_status = op_wishrnd::apply_noalias_mode2(tmp, Dinv, df);
  192. if(wishrnd_status == false) { return false; }
  193. const bool inv_status1 = auxlib::inv_sympd(out, tmp);
  194. const bool inv_status2 = (inv_status1) ? bool(true) : bool(auxlib::inv(out, tmp));
  195. if(inv_status2 == false) { return false; }
  196. return true;
  197. }
  198. #else
  199. {
  200. arma_ignore(out);
  201. arma_ignore(Dinv);
  202. arma_ignore(df);
  203. arma_stop_logic_error("iwishrnd(): C++11 compiler required");
  204. return false;
  205. }
  206. #endif
  207. }
  208. //! @}