arma_rng.hpp 12 KB


  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup arma_rng
  16. //! @{
  17. #include "stdlib.h"
  // The extern C++11 engine is only used when C++11 support is enabled and no
  // user-supplied alternative generator (ARMA_RNG_ALT) takes precedence.
  #if defined(ARMA_RNG_ALT) || !defined(ARMA_USE_CXX11)
    #undef ARMA_USE_EXTERN_CXX11_RNG
  #endif
  
  #if defined(ARMA_USE_EXTERN_CXX11_RNG)
    // Thread-local engine instance; only declared here (extern), defined elsewhere.
    extern thread_local arma_rng_cxx11 arma_rng_cxx11_instance;
    // alternative kept for reference: namespace { thread_local arma_rng_cxx11 arma_rng_cxx11_instance; }
  #endif
  28. class arma_rng
  29. {
  30. public:
  31. #if defined(ARMA_RNG_ALT)
  32. typedef arma_rng_alt::seed_type seed_type;
  33. #elif defined(ARMA_USE_EXTERN_CXX11_RNG)
  34. typedef arma_rng_cxx11::seed_type seed_type;
  35. #else
  36. typedef arma_rng_cxx98::seed_type seed_type;
  37. #endif
  38. #if defined(ARMA_RNG_ALT)
  39. static const int rng_method = 2;
  40. #elif defined(ARMA_USE_EXTERN_CXX11_RNG)
  41. static const int rng_method = 1;
  42. #else
  43. static const int rng_method = 0;
  44. #endif
  45. inline static void set_seed(const seed_type val);
  46. inline static void set_seed_random();
  47. template<typename eT> struct randi;
  48. template<typename eT> struct randu;
  49. template<typename eT> struct randn;
  50. };
  51. inline
  52. void
  53. arma_rng::set_seed(const arma_rng::seed_type val)
  54. {
  55. #if defined(ARMA_RNG_ALT)
  56. {
  57. arma_rng_alt::set_seed(val);
  58. }
  59. #elif defined(ARMA_USE_EXTERN_CXX11_RNG)
  60. {
  61. arma_rng_cxx11_instance.set_seed(val);
  62. }
  63. #else
  64. {
  65. arma_rng_cxx98::set_seed(val);
  66. }
  67. #endif
  68. }
  69. arma_cold
  70. inline
  71. void
  72. arma_rng::set_seed_random()
  73. {
  74. seed_type seed1 = seed_type(0);
  75. seed_type seed2 = seed_type(0);
  76. seed_type seed3 = seed_type(0);
  77. seed_type seed4 = seed_type(0);
  78. seed_type seed5 = seed_type(0);
  79. bool have_seed = false;
  80. #if defined(ARMA_USE_CXX11)
  81. {
  82. try
  83. {
  84. std::random_device rd;
  85. if(rd.entropy() > double(0)) { seed1 = static_cast<seed_type>( rd() ); }
  86. if(seed1 != seed_type(0)) { have_seed = true; }
  87. }
  88. catch(...) {}
  89. }
  90. #endif
  91. if(have_seed == false)
  92. {
  93. try
  94. {
  95. union
  96. {
  97. seed_type a;
  98. unsigned char b[sizeof(seed_type)];
  99. } tmp;
  100. tmp.a = seed_type(0);
  101. std::ifstream f("/dev/urandom", std::ifstream::binary);
  102. if(f.good()) { f.read((char*)(&(tmp.b[0])), sizeof(seed_type)); }
  103. if(f.good())
  104. {
  105. seed2 = tmp.a;
  106. if(seed2 != seed_type(0)) { have_seed = true; }
  107. }
  108. }
  109. catch(...) {}
  110. }
  111. if(have_seed == false)
  112. {
  113. // get better-than-nothing seeds in case reading /dev/urandom failed
  114. #if defined(ARMA_HAVE_GETTIMEOFDAY)
  115. {
  116. struct timeval posix_time;
  117. gettimeofday(&posix_time, 0);
  118. seed3 = static_cast<seed_type>(posix_time.tv_usec);
  119. }
  120. #endif
  121. seed4 = static_cast<seed_type>( std::time(NULL) & 0xFFFF );
  122. union
  123. {
  124. uword* a;
  125. unsigned char b[sizeof(uword*)];
  126. } tmp;
  127. tmp.a = (uword*)malloc(sizeof(uword));
  128. if(tmp.a != NULL)
  129. {
  130. for(size_t i=0; i<sizeof(uword*); ++i) { seed5 += seed_type(tmp.b[i]); }
  131. free(tmp.a);
  132. }
  133. }
  134. arma_rng::set_seed( seed1 + seed2 + seed3 + seed4 + seed5 );
  135. }
  136. template<typename eT>
  137. struct arma_rng::randi
  138. {
  139. arma_inline
  140. operator eT ()
  141. {
  142. #if defined(ARMA_RNG_ALT)
  143. {
  144. return eT( arma_rng_alt::randi_val() );
  145. }
  146. #elif defined(ARMA_USE_EXTERN_CXX11_RNG)
  147. {
  148. return eT( arma_rng_cxx11_instance.randi_val() );
  149. }
  150. #else
  151. {
  152. return eT( arma_rng_cxx98::randi_val() );
  153. }
  154. #endif
  155. }
  156. inline
  157. static
  158. int
  159. max_val()
  160. {
  161. #if defined(ARMA_RNG_ALT)
  162. {
  163. return arma_rng_alt::randi_max_val();
  164. }
  165. #elif defined(ARMA_USE_EXTERN_CXX11_RNG)
  166. {
  167. return arma_rng_cxx11::randi_max_val();
  168. }
  169. #else
  170. {
  171. return arma_rng_cxx98::randi_max_val();
  172. }
  173. #endif
  174. }
  175. inline
  176. static
  177. void
  178. fill(eT* mem, const uword N, const int a, const int b)
  179. {
  180. #if defined(ARMA_RNG_ALT)
  181. {
  182. arma_rng_alt::randi_fill(mem, N, a, b);
  183. }
  184. #elif defined(ARMA_USE_EXTERN_CXX11_RNG)
  185. {
  186. arma_rng_cxx11_instance.randi_fill(mem, N, a, b);
  187. }
  188. #else
  189. {
  190. arma_rng_cxx98::randi_fill(mem, N, a, b);
  191. }
  192. #endif
  193. }
  194. };
  195. template<typename eT>
  196. struct arma_rng::randu
  197. {
  198. arma_inline
  199. operator eT ()
  200. {
  201. #if defined(ARMA_RNG_ALT)
  202. {
  203. return eT( arma_rng_alt::randu_val() );
  204. }
  205. #elif defined(ARMA_USE_EXTERN_CXX11_RNG)
  206. {
  207. return eT( arma_rng_cxx11_instance.randu_val() );
  208. }
  209. #else
  210. {
  211. return eT( arma_rng_cxx98::randu_val() );
  212. }
  213. #endif
  214. }
  215. inline
  216. static
  217. void
  218. fill(eT* mem, const uword N)
  219. {
  220. uword j;
  221. for(j=1; j < N; j+=2)
  222. {
  223. const eT tmp_i = eT( arma_rng::randu<eT>() );
  224. const eT tmp_j = eT( arma_rng::randu<eT>() );
  225. (*mem) = tmp_i; mem++;
  226. (*mem) = tmp_j; mem++;
  227. }
  228. if((j-1) < N)
  229. {
  230. (*mem) = eT( arma_rng::randu<eT>() );
  231. }
  232. }
  233. };
  234. template<typename T>
  235. struct arma_rng::randu< std::complex<T> >
  236. {
  237. arma_inline
  238. operator std::complex<T> ()
  239. {
  240. const T a = T( arma_rng::randu<T>() );
  241. const T b = T( arma_rng::randu<T>() );
  242. return std::complex<T>(a, b);
  243. }
  244. inline
  245. static
  246. void
  247. fill(std::complex<T>* mem, const uword N)
  248. {
  249. for(uword i=0; i < N; ++i)
  250. {
  251. const T a = T( arma_rng::randu<T>() );
  252. const T b = T( arma_rng::randu<T>() );
  253. mem[i] = std::complex<T>(a, b);
  254. }
  255. }
  256. };
  257. template<typename eT>
  258. struct arma_rng::randn
  259. {
  260. inline
  261. operator eT () const
  262. {
  263. #if defined(ARMA_RNG_ALT)
  264. {
  265. return eT( arma_rng_alt::randn_val() );
  266. }
  267. #elif defined(ARMA_USE_EXTERN_CXX11_RNG)
  268. {
  269. return eT( arma_rng_cxx11_instance.randn_val() );
  270. }
  271. #else
  272. {
  273. return eT( arma_rng_cxx98::randn_val() );
  274. }
  275. #endif
  276. }
  277. inline
  278. static
  279. void
  280. dual_val(eT& out1, eT& out2)
  281. {
  282. #if defined(ARMA_RNG_ALT)
  283. {
  284. arma_rng_alt::randn_dual_val(out1, out2);
  285. }
  286. #elif defined(ARMA_USE_EXTERN_CXX11_RNG)
  287. {
  288. arma_rng_cxx11_instance.randn_dual_val(out1, out2);
  289. }
  290. #else
  291. {
  292. arma_rng_cxx98::randn_dual_val(out1, out2);
  293. }
  294. #endif
  295. }
  296. inline
  297. static
  298. void
  299. fill_simple(eT* mem, const uword N)
  300. {
  301. uword i, j;
  302. for(i=0, j=1; j < N; i+=2, j+=2)
  303. {
  304. arma_rng::randn<eT>::dual_val( mem[i], mem[j] );
  305. }
  306. if(i < N)
  307. {
  308. mem[i] = eT( arma_rng::randn<eT>() );
  309. }
  310. }
  311. inline
  312. static
  313. void
  314. fill(eT* mem, const uword N)
  315. {
  316. #if defined(ARMA_USE_CXX11) && defined(ARMA_USE_OPENMP)
  317. {
  318. if((N < 1024) || omp_in_parallel()) { arma_rng::randn<eT>::fill_simple(mem, N); return; }
  319. typedef std::mt19937_64::result_type seed_type;
  320. const uword n_threads = uword( mp_thread_limit::get() );
  321. std::vector< std::mt19937_64 > engine(n_threads);
  322. std::vector< std::normal_distribution<double> > distr(n_threads);
  323. for(uword t=0; t < n_threads; ++t)
  324. {
  325. std::mt19937_64& t_engine = engine[t];
  326. t_engine.seed( seed_type(t) + seed_type(arma_rng::randi<seed_type>()) );
  327. }
  328. const uword chunk_size = N / n_threads;
  329. #pragma omp parallel for schedule(static) num_threads(int(n_threads))
  330. for(uword t=0; t < n_threads; ++t)
  331. {
  332. const uword start = (t+0) * chunk_size;
  333. const uword endp1 = (t+1) * chunk_size;
  334. std::mt19937_64& t_engine = engine[t];
  335. std::normal_distribution<double>& t_distr = distr[t];
  336. for(uword i=start; i < endp1; ++i) { mem[i] = eT( t_distr(t_engine)); }
  337. }
  338. std::mt19937_64& t0_engine = engine[0];
  339. std::normal_distribution<double>& t0_distr = distr[0];
  340. for(uword i=(n_threads*chunk_size); i < N; ++i) { mem[i] = eT( t0_distr(t0_engine)); }
  341. }
  342. #else
  343. {
  344. arma_rng::randn<eT>::fill_simple(mem, N);
  345. }
  346. #endif
  347. }
  348. };
  349. template<typename T>
  350. struct arma_rng::randn< std::complex<T> >
  351. {
  352. inline
  353. operator std::complex<T> () const
  354. {
  355. #if defined(_MSC_VER)
  356. // attempt at workaround for MSVC bug
  357. // does MS even test their so-called compilers before release?
  358. T a;
  359. T b;
  360. #else
  361. T a(0);
  362. T b(0);
  363. #endif
  364. arma_rng::randn<T>::dual_val(a, b);
  365. return std::complex<T>(a, b);
  366. }
  367. inline
  368. static
  369. void
  370. dual_val(std::complex<T>& out1, std::complex<T>& out2)
  371. {
  372. #if defined(_MSC_VER)
  373. T a;
  374. T b;
  375. #else
  376. T a(0);
  377. T b(0);
  378. #endif
  379. arma_rng::randn<T>::dual_val(a,b);
  380. out1 = std::complex<T>(a,b);
  381. arma_rng::randn<T>::dual_val(a,b);
  382. out2 = std::complex<T>(a,b);
  383. }
  384. inline
  385. static
  386. void
  387. fill_simple(std::complex<T>* mem, const uword N)
  388. {
  389. for(uword i=0; i < N; ++i)
  390. {
  391. mem[i] = std::complex<T>( arma_rng::randn< std::complex<T> >() );
  392. }
  393. }
  394. inline
  395. static
  396. void
  397. fill(std::complex<T>* mem, const uword N)
  398. {
  399. #if defined(ARMA_USE_CXX11) && defined(ARMA_USE_OPENMP)
  400. {
  401. if((N < 512) || omp_in_parallel()) { arma_rng::randn< std::complex<T> >::fill_simple(mem, N); return; }
  402. typedef std::mt19937_64::result_type seed_type;
  403. const uword n_threads = uword( mp_thread_limit::get() );
  404. std::vector< std::mt19937_64 > engine(n_threads);
  405. std::vector< std::normal_distribution<double> > distr(n_threads);
  406. for(uword t=0; t < n_threads; ++t)
  407. {
  408. std::mt19937_64& t_engine = engine[t];
  409. t_engine.seed( seed_type(t) + seed_type(arma_rng::randi<seed_type>()) );
  410. }
  411. const uword chunk_size = N / n_threads;
  412. #pragma omp parallel for schedule(static) num_threads(int(n_threads))
  413. for(uword t=0; t < n_threads; ++t)
  414. {
  415. const uword start = (t+0) * chunk_size;
  416. const uword endp1 = (t+1) * chunk_size;
  417. std::mt19937_64& t_engine = engine[t];
  418. std::normal_distribution<double>& t_distr = distr[t];
  419. for(uword i=start; i < endp1; ++i)
  420. {
  421. const T val1 = T( t_distr(t_engine) );
  422. const T val2 = T( t_distr(t_engine) );
  423. mem[i] = std::complex<T>(val1, val2);
  424. }
  425. }
  426. std::mt19937_64& t0_engine = engine[0];
  427. std::normal_distribution<double>& t0_distr = distr[0];
  428. for(uword i=(n_threads*chunk_size); i < N; ++i)
  429. {
  430. const T val1 = T( t0_distr(t0_engine) );
  431. const T val2 = T( t0_distr(t0_engine) );
  432. mem[i] = std::complex<T>(val1, val2);
  433. }
  434. }
  435. #else
  436. {
  437. arma_rng::randn< std::complex<T> >::fill_simple(mem, N);
  438. }
  439. #endif
  440. }
  441. };
  442. //! @}