arma_rng_cxx98.hpp 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup arma_rng_cxx98
  16. //! @{
  17. class arma_rng_cxx98
  18. {
  19. public:
  20. typedef unsigned int seed_type;
  21. inline static void set_seed(const seed_type val);
  22. arma_inline static int randi_val();
  23. arma_inline static double randu_val();
  24. inline static double randn_val();
  25. template<typename eT>
  26. inline static void randn_dual_val(eT& out1, eT& out2);
  27. template<typename eT>
  28. inline static void randi_fill(eT* mem, const uword N, const int a, const int b);
  29. inline static int randi_max_val();
  30. };
  31. inline
  32. void
  33. arma_rng_cxx98::set_seed(const arma_rng_cxx98::seed_type val)
  34. {
  35. std::srand(val);
  36. }
  37. arma_inline
  38. int
  39. arma_rng_cxx98::randi_val()
  40. {
  41. #if (RAND_MAX == 32767)
  42. {
  43. // NOTE: this is a better-than-nothing solution
  44. // NOTE: see also arma_rng_cxx98::randi_max_val()
  45. u32 val1 = u32(std::rand());
  46. u32 val2 = u32(std::rand());
  47. val1 <<= 15;
  48. return (val1 | val2);
  49. }
  50. #else
  51. {
  52. return std::rand();
  53. }
  54. #endif
  55. }
  56. arma_inline
  57. double
  58. arma_rng_cxx98::randu_val()
  59. {
  60. return double( double(randi_val()) * ( double(1) / double(randi_max_val()) ) );
  61. }
  62. inline
  63. double
  64. arma_rng_cxx98::randn_val()
  65. {
  66. // polar form of the Box-Muller transformation:
  67. // http://en.wikipedia.org/wiki/Box-Muller_transformation
  68. // http://en.wikipedia.org/wiki/Marsaglia_polar_method
  69. double tmp1 = double(0);
  70. double tmp2 = double(0);
  71. double w = double(0);
  72. do
  73. {
  74. tmp1 = double(2) * double(randi_val()) * (double(1) / double(randi_max_val())) - double(1);
  75. tmp2 = double(2) * double(randi_val()) * (double(1) / double(randi_max_val())) - double(1);
  76. w = tmp1*tmp1 + tmp2*tmp2;
  77. }
  78. while ( w >= double(1) );
  79. return double( tmp1 * std::sqrt( (double(-2) * std::log(w)) / w) );
  80. }
  81. template<typename eT>
  82. inline
  83. void
  84. arma_rng_cxx98::randn_dual_val(eT& out1, eT& out2)
  85. {
  86. // make sure we are internally using at least floats
  87. typedef typename promote_type<eT,float>::result eTp;
  88. eTp tmp1 = eTp(0);
  89. eTp tmp2 = eTp(0);
  90. eTp w = eTp(0);
  91. do
  92. {
  93. tmp1 = eTp(2) * eTp(randi_val()) * (eTp(1) / eTp(randi_max_val())) - eTp(1);
  94. tmp2 = eTp(2) * eTp(randi_val()) * (eTp(1) / eTp(randi_max_val())) - eTp(1);
  95. w = tmp1*tmp1 + tmp2*tmp2;
  96. }
  97. while ( w >= eTp(1) );
  98. const eTp k = std::sqrt( (eTp(-2) * std::log(w)) / w);
  99. out1 = eT(tmp1*k);
  100. out2 = eT(tmp2*k);
  101. }
  102. template<typename eT>
  103. inline
  104. void
  105. arma_rng_cxx98::randi_fill(eT* mem, const uword N, const int a, const int b)
  106. {
  107. if( (a == 0) && (b == RAND_MAX) )
  108. {
  109. for(uword i=0; i<N; ++i)
  110. {
  111. mem[i] = eT(std::rand());
  112. }
  113. }
  114. else
  115. {
  116. const uword length = uword(b - a + 1);
  117. const double scale = double(length) / double(randi_max_val());
  118. for(uword i=0; i<N; ++i)
  119. {
  120. mem[i] = eT((std::min)( b, (int( double(randi_val()) * scale ) + a) ));
  121. }
  122. }
  123. }
  124. inline
  125. int
  126. arma_rng_cxx98::randi_max_val()
  127. {
  128. #if (RAND_MAX == 32767)
  129. return ( (32767 << 15) + 32767);
  130. #else
  131. return RAND_MAX;
  132. #endif
  133. }
  134. //! @}