// fn_normcdf.hpp

// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
// Copyright 2008-2016 National ICT Australia (NICTA)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ------------------------------------------------------------------------


//! \addtogroup fn_normcdf
//! @{
  17. template<typename T1, typename T2, typename T3>
  18. inline
  19. typename enable_if2< (is_real<typename T1::elem_type>::value), void >::result
  20. normcdf_helper(Mat<typename T1::elem_type>& out, const Base<typename T1::elem_type, T1>& X_expr, const Base<typename T1::elem_type, T2>& M_expr, const Base<typename T1::elem_type, T3>& S_expr)
  21. {
  22. arma_extra_debug_sigprint();
  23. #if !defined(ARMA_USE_CXX11)
  24. {
  25. arma_stop_logic_error("normcdf(): C++11 compiler required");
  26. return;
  27. }
  28. #else
  29. {
  30. typedef typename T1::elem_type eT;
  31. if(Proxy<T1>::use_at || Proxy<T2>::use_at || Proxy<T3>::use_at)
  32. {
  33. const quasi_unwrap<T1> UX(X_expr.get_ref());
  34. const quasi_unwrap<T2> UM(M_expr.get_ref());
  35. const quasi_unwrap<T3> US(S_expr.get_ref());
  36. normcdf_helper(out, UX.M, UM.M, US.M);
  37. return;
  38. }
  39. const Proxy<T1> PX(X_expr.get_ref());
  40. const Proxy<T2> PM(M_expr.get_ref());
  41. const Proxy<T3> PS(S_expr.get_ref());
  42. arma_debug_check( ( (PX.get_n_rows() != PM.get_n_rows()) || (PX.get_n_cols() != PM.get_n_cols()) || (PM.get_n_rows() != PS.get_n_rows()) || (PM.get_n_cols() != PS.get_n_cols()) ), "normcdf(): size mismatch" );
  43. out.set_size(PX.get_n_rows(), PX.get_n_cols());
  44. eT* out_mem = out.memptr();
  45. const uword N = PX.get_n_elem();
  46. typename Proxy<T1>::ea_type X_ea = PX.get_ea();
  47. typename Proxy<T2>::ea_type M_ea = PM.get_ea();
  48. typename Proxy<T3>::ea_type S_ea = PS.get_ea();
  49. const bool use_mp = arma_config::cxx11 && arma_config::openmp && mp_gate<eT,true>::eval(N);
  50. if(use_mp)
  51. {
  52. #if defined(ARMA_USE_OPENMP)
  53. {
  54. const int n_threads = mp_thread_limit::get();
  55. #pragma omp parallel for schedule(static) num_threads(n_threads)
  56. for(uword i=0; i<N; ++i)
  57. {
  58. const eT tmp = (X_ea[i] - M_ea[i]) / (S_ea[i] * (-Datum<eT>::sqrt2));
  59. out_mem[i] = eT(0.5) * std::erfc(tmp);
  60. }
  61. }
  62. #endif
  63. }
  64. else
  65. {
  66. for(uword i=0; i<N; ++i)
  67. {
  68. const eT tmp = (X_ea[i] - M_ea[i]) / (S_ea[i] * (-Datum<eT>::sqrt2));
  69. out_mem[i] = eT(0.5) * std::erfc(tmp);
  70. }
  71. }
  72. }
  73. #endif
  74. }
  75. template<typename eT>
  76. inline
  77. arma_warn_unused
  78. typename enable_if2< (is_real<eT>::value), eT >::result
  79. normcdf(const eT x)
  80. {
  81. #if !defined(ARMA_USE_CXX11)
  82. {
  83. arma_stop_logic_error("normcdf(): C++11 compiler required");
  84. return eT(0);
  85. }
  86. #else
  87. {
  88. const eT out = eT(0.5) * std::erfc( x / (-Datum<eT>::sqrt2) );
  89. return out;
  90. }
  91. #endif
  92. }
  93. template<typename eT>
  94. inline
  95. arma_warn_unused
  96. typename enable_if2< (is_real<eT>::value), eT >::result
  97. normcdf(const eT x, const eT mu, const eT sigma)
  98. {
  99. #if !defined(ARMA_USE_CXX11)
  100. {
  101. arma_stop_logic_error("normcdf(): C++11 compiler required");
  102. return eT(0);
  103. }
  104. #else
  105. {
  106. const eT tmp = (x - mu) / (sigma * (-Datum<eT>::sqrt2));
  107. const eT out = eT(0.5) * std::erfc(tmp);
  108. return out;
  109. }
  110. #endif
  111. }
  112. template<typename eT, typename T2, typename T3>
  113. inline
  114. arma_warn_unused
  115. typename enable_if2< (is_real<eT>::value), Mat<eT> >::result
  116. normcdf(const eT x, const Base<eT, T2>& M_expr, const Base<eT, T3>& S_expr)
  117. {
  118. arma_extra_debug_sigprint();
  119. const quasi_unwrap<T2> UM(M_expr.get_ref());
  120. const Mat<eT>& M = UM.M;
  121. Mat<eT> out;
  122. normcdf_helper(out, x*ones< Mat<eT> >(arma::size(M)), M, S_expr.get_ref());
  123. return out;
  124. }
  125. template<typename T1>
  126. inline
  127. arma_warn_unused
  128. typename enable_if2< (is_real<typename T1::elem_type>::value), Mat<typename T1::elem_type> >::result
  129. normcdf(const Base<typename T1::elem_type, T1>& X_expr)
  130. {
  131. arma_extra_debug_sigprint();
  132. typedef typename T1::elem_type eT;
  133. const quasi_unwrap<T1> UX(X_expr.get_ref());
  134. const Mat<eT>& X = UX.M;
  135. Mat<eT> out;
  136. normcdf_helper(out, X, zeros< Mat<eT> >(arma::size(X)), ones< Mat<eT> >(arma::size(X)));
  137. return out;
  138. }
  139. template<typename T1>
  140. inline
  141. arma_warn_unused
  142. typename enable_if2< (is_real<typename T1::elem_type>::value), Mat<typename T1::elem_type> >::result
  143. normcdf(const Base<typename T1::elem_type, T1>& X_expr, const typename T1::elem_type mu, const typename T1::elem_type sigma)
  144. {
  145. arma_extra_debug_sigprint();
  146. typedef typename T1::elem_type eT;
  147. const quasi_unwrap<T1> UX(X_expr.get_ref());
  148. const Mat<eT>& X = UX.M;
  149. Mat<eT> out;
  150. normcdf_helper(out, X, mu*ones< Mat<eT> >(arma::size(X)), sigma*ones< Mat<eT> >(arma::size(X)));
  151. return out;
  152. }
  153. template<typename T1, typename T2, typename T3>
  154. inline
  155. arma_warn_unused
  156. typename enable_if2< (is_real<typename T1::elem_type>::value), Mat<typename T1::elem_type> >::result
  157. normcdf(const Base<typename T1::elem_type, T1>& X_expr, const Base<typename T1::elem_type, T2>& M_expr, const Base<typename T1::elem_type, T3>& S_expr)
  158. {
  159. arma_extra_debug_sigprint();
  160. typedef typename T1::elem_type eT;
  161. Mat<eT> out;
  162. normcdf_helper(out, X_expr.get_ref(), M_expr.get_ref(), S_expr.get_ref());
  163. return out;
  164. }
//! @}