fn_normpdf.hpp 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup fn_normpdf
  16. //! @{
  17. template<typename T1, typename T2, typename T3>
  18. inline
  19. typename enable_if2< (is_real<typename T1::elem_type>::value), void >::result
  20. normpdf_helper(Mat<typename T1::elem_type>& out, const Base<typename T1::elem_type, T1>& X_expr, const Base<typename T1::elem_type, T2>& M_expr, const Base<typename T1::elem_type, T3>& S_expr)
  21. {
  22. arma_extra_debug_sigprint();
  23. typedef typename T1::elem_type eT;
  24. if(Proxy<T1>::use_at || Proxy<T2>::use_at || Proxy<T3>::use_at)
  25. {
  26. const quasi_unwrap<T1> UX(X_expr.get_ref());
  27. const quasi_unwrap<T2> UM(M_expr.get_ref());
  28. const quasi_unwrap<T3> US(S_expr.get_ref());
  29. normpdf_helper(out, UX.M, UM.M, US.M);
  30. return;
  31. }
  32. const Proxy<T1> PX(X_expr.get_ref());
  33. const Proxy<T2> PM(M_expr.get_ref());
  34. const Proxy<T3> PS(S_expr.get_ref());
  35. arma_debug_check( ( (PX.get_n_rows() != PM.get_n_rows()) || (PX.get_n_cols() != PM.get_n_cols()) || (PM.get_n_rows() != PS.get_n_rows()) || (PM.get_n_cols() != PS.get_n_cols()) ), "normpdf(): size mismatch" );
  36. out.set_size(PX.get_n_rows(), PX.get_n_cols());
  37. eT* out_mem = out.memptr();
  38. const uword N = PX.get_n_elem();
  39. typename Proxy<T1>::ea_type X_ea = PX.get_ea();
  40. typename Proxy<T2>::ea_type M_ea = PM.get_ea();
  41. typename Proxy<T3>::ea_type S_ea = PS.get_ea();
  42. const bool use_mp = arma_config::cxx11 && arma_config::openmp && mp_gate<eT,true>::eval(N);
  43. if(use_mp)
  44. {
  45. #if defined(ARMA_USE_OPENMP)
  46. {
  47. const int n_threads = mp_thread_limit::get();
  48. #pragma omp parallel for schedule(static) num_threads(n_threads)
  49. for(uword i=0; i<N; ++i)
  50. {
  51. const eT sigma = S_ea[i];
  52. const eT tmp = (X_ea[i] - M_ea[i]) / sigma;
  53. out_mem[i] = std::exp(eT(-0.5) * (tmp*tmp)) / (sigma * Datum<eT>::sqrt2pi);
  54. }
  55. }
  56. #endif
  57. }
  58. else
  59. {
  60. for(uword i=0; i<N; ++i)
  61. {
  62. const eT sigma = S_ea[i];
  63. const eT tmp = (X_ea[i] - M_ea[i]) / sigma;
  64. out_mem[i] = std::exp(eT(-0.5) * (tmp*tmp)) / (sigma * Datum<eT>::sqrt2pi);
  65. }
  66. }
  67. }
  68. template<typename eT>
  69. inline
  70. arma_warn_unused
  71. typename enable_if2< (is_real<eT>::value), eT >::result
  72. normpdf(const eT x)
  73. {
  74. const eT out = std::exp(eT(-0.5) * (x*x)) / Datum<eT>::sqrt2pi;
  75. return out;
  76. }
  77. template<typename eT>
  78. inline
  79. arma_warn_unused
  80. typename enable_if2< (is_real<eT>::value), eT >::result
  81. normpdf(const eT x, const eT mu, const eT sigma)
  82. {
  83. const eT tmp = (x - mu) / sigma;
  84. const eT out = std::exp(eT(-0.5) * (tmp*tmp)) / (sigma * Datum<eT>::sqrt2pi);
  85. return out;
  86. }
  87. template<typename eT, typename T2, typename T3>
  88. inline
  89. arma_warn_unused
  90. typename enable_if2< (is_real<eT>::value), Mat<eT> >::result
  91. normpdf(const eT x, const Base<eT, T2>& M_expr, const Base<eT, T3>& S_expr)
  92. {
  93. arma_extra_debug_sigprint();
  94. const quasi_unwrap<T2> UM(M_expr.get_ref());
  95. const Mat<eT>& M = UM.M;
  96. Mat<eT> out;
  97. normpdf_helper(out, x*ones< Mat<eT> >(arma::size(M)), M, S_expr.get_ref());
  98. return out;
  99. }
  100. template<typename T1>
  101. inline
  102. arma_warn_unused
  103. typename enable_if2< (is_real<typename T1::elem_type>::value), Mat<typename T1::elem_type> >::result
  104. normpdf(const Base<typename T1::elem_type, T1>& X_expr)
  105. {
  106. arma_extra_debug_sigprint();
  107. typedef typename T1::elem_type eT;
  108. const quasi_unwrap<T1> UX(X_expr.get_ref());
  109. const Mat<eT>& X = UX.M;
  110. Mat<eT> out;
  111. normpdf_helper(out, X, zeros< Mat<eT> >(arma::size(X)), ones< Mat<eT> >(arma::size(X)));
  112. return out;
  113. }
  114. template<typename T1>
  115. inline
  116. arma_warn_unused
  117. typename enable_if2< (is_real<typename T1::elem_type>::value), Mat<typename T1::elem_type> >::result
  118. normpdf(const Base<typename T1::elem_type, T1>& X_expr, const typename T1::elem_type mu, const typename T1::elem_type sigma)
  119. {
  120. arma_extra_debug_sigprint();
  121. typedef typename T1::elem_type eT;
  122. const quasi_unwrap<T1> UX(X_expr.get_ref());
  123. const Mat<eT>& X = UX.M;
  124. Mat<eT> out;
  125. normpdf_helper(out, X, mu*ones< Mat<eT> >(arma::size(X)), sigma*ones< Mat<eT> >(arma::size(X)));
  126. return out;
  127. }
  128. template<typename T1, typename T2, typename T3>
  129. inline
  130. arma_warn_unused
  131. typename enable_if2< (is_real<typename T1::elem_type>::value), Mat<typename T1::elem_type> >::result
  132. normpdf(const Base<typename T1::elem_type, T1>& X_expr, const Base<typename T1::elem_type, T2>& M_expr, const Base<typename T1::elem_type, T3>& S_expr)
  133. {
  134. arma_extra_debug_sigprint();
  135. typedef typename T1::elem_type eT;
  136. Mat<eT> out;
  137. normpdf_helper(out, X_expr.get_ref(), M_expr.get_ref(), S_expr.get_ref());
  138. return out;
  139. }
  140. //! @}