arma_rng_cxx11.hpp 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup arma_rng_cxx11
  16. //! @{
  17. #if defined(ARMA_USE_CXX11)
  18. class arma_rng_cxx11
  19. {
  20. public:
  21. typedef std::mt19937_64::result_type seed_type;
  22. inline void set_seed(const seed_type val);
  23. arma_inline int randi_val();
  24. arma_inline double randu_val();
  25. arma_inline double randn_val();
  26. template<typename eT>
  27. arma_inline void randn_dual_val(eT& out1, eT& out2);
  28. template<typename eT>
  29. inline void randi_fill(eT* mem, const uword N, const int a, const int b);
  30. inline static int randi_max_val();
  31. template<typename eT>
  32. inline void randg_fill_simple(eT* mem, const uword N, const double a, const double b);
  33. template<typename eT>
  34. inline void randg_fill(eT* mem, const uword N, const double a, const double b);
  35. private:
  36. arma_aligned std::mt19937_64 engine; // typedef for std::mersenne_twister_engine with preset parameters
  37. arma_aligned std::uniform_int_distribution<int> i_distr; // by default uses a=0, b=std::numeric_limits<int>::max()
  38. arma_aligned std::uniform_real_distribution<double> u_distr; // by default uses [0,1) interval
  39. arma_aligned std::normal_distribution<double> n_distr; // by default uses mean=0.0 and stddev=1.0
  40. };
  41. inline
  42. void
  43. arma_rng_cxx11::set_seed(const arma_rng_cxx11::seed_type val)
  44. {
  45. engine.seed(val);
  46. i_distr.reset();
  47. u_distr.reset();
  48. n_distr.reset();
  49. }
  50. arma_inline
  51. int
  52. arma_rng_cxx11::randi_val()
  53. {
  54. return i_distr(engine);
  55. }
  56. arma_inline
  57. double
  58. arma_rng_cxx11::randu_val()
  59. {
  60. return u_distr(engine);
  61. }
  62. arma_inline
  63. double
  64. arma_rng_cxx11::randn_val()
  65. {
  66. return n_distr(engine);
  67. }
  68. template<typename eT>
  69. arma_inline
  70. void
  71. arma_rng_cxx11::randn_dual_val(eT& out1, eT& out2)
  72. {
  73. out1 = eT( n_distr(engine) );
  74. out2 = eT( n_distr(engine) );
  75. }
  76. template<typename eT>
  77. inline
  78. void
  79. arma_rng_cxx11::randi_fill(eT* mem, const uword N, const int a, const int b)
  80. {
  81. std::uniform_int_distribution<int> local_i_distr(a, b);
  82. for(uword i=0; i<N; ++i)
  83. {
  84. mem[i] = eT(local_i_distr(engine));
  85. }
  86. }
  87. inline
  88. int
  89. arma_rng_cxx11::randi_max_val()
  90. {
  91. return std::numeric_limits<int>::max();
  92. }
  93. template<typename eT>
  94. inline
  95. void
  96. arma_rng_cxx11::randg_fill_simple(eT* mem, const uword N, const double a, const double b)
  97. {
  98. std::gamma_distribution<double> g_distr(a,b);
  99. for(uword i=0; i<N; ++i)
  100. {
  101. mem[i] = eT(g_distr(engine));
  102. }
  103. }
  104. template<typename eT>
  105. inline
  106. void
  107. arma_rng_cxx11::randg_fill(eT* mem, const uword N, const double a, const double b)
  108. {
  109. #if defined(ARMA_USE_OPENMP)
  110. {
  111. if((N < 512) || omp_in_parallel()) { (*this).randg_fill_simple(mem, N, a, b); return; }
  112. typedef std::mt19937_64 motor_type;
  113. typedef std::mt19937_64::result_type ovum_type;
  114. typedef std::gamma_distribution<double> distr_type;
  115. const uword n_threads = uword( mp_thread_limit::get() );
  116. std::vector<motor_type> g_motor(n_threads);
  117. std::vector<distr_type> g_distr(n_threads);
  118. const distr_type g_distr_base(a,b);
  119. for(uword t=0; t < n_threads; ++t)
  120. {
  121. motor_type& g_motor_t = g_motor[t];
  122. distr_type& g_distr_t = g_distr[t];
  123. g_motor_t.seed( ovum_type(t) + ovum_type((*this).randi_val()) );
  124. g_distr_t.param( g_distr_base.param() );
  125. }
  126. const uword chunk_size = N / n_threads;
  127. #pragma omp parallel for schedule(static) num_threads(int(n_threads))
  128. for(uword t=0; t < n_threads; ++t)
  129. {
  130. const uword start = (t+0) * chunk_size;
  131. const uword endp1 = (t+1) * chunk_size;
  132. motor_type& g_motor_t = g_motor[t];
  133. distr_type& g_distr_t = g_distr[t];
  134. for(uword i=start; i < endp1; ++i) { mem[i] = eT( g_distr_t(g_motor_t)); }
  135. }
  136. motor_type& g_motor_0 = g_motor[0];
  137. distr_type& g_distr_0 = g_distr[0];
  138. for(uword i=(n_threads*chunk_size); i < N; ++i) { mem[i] = eT( g_distr_0(g_motor_0)); }
  139. }
  140. #else
  141. {
  142. (*this).randg_fill_simple(mem, N, a, b);
  143. }
  144. #endif
  145. }
  146. #endif
  147. //! @}