fn_misc.hpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup fn_misc
  16. //! @{
  17. template<typename out_type>
  18. arma_warn_unused
  19. inline
  20. typename
  21. enable_if2
  22. <
  23. is_Mat<out_type>::value,
  24. out_type
  25. >::result
  26. linspace
  27. (
  28. const typename out_type::pod_type start,
  29. const typename out_type::pod_type end,
  30. const uword num = 100u
  31. )
  32. {
  33. arma_extra_debug_sigprint();
  34. typedef typename out_type::elem_type eT;
  35. typedef typename out_type::pod_type T;
  36. out_type x;
  37. if(num == 1)
  38. {
  39. x.set_size(1);
  40. x[0] = eT(end);
  41. }
  42. else
  43. if(num >= 2)
  44. {
  45. x.set_size(num);
  46. eT* x_mem = x.memptr();
  47. const uword num_m1 = num - 1;
  48. if(is_non_integral<T>::value == true)
  49. {
  50. const T delta = (end-start)/T(num_m1);
  51. for(uword i=0; i<num_m1; ++i)
  52. {
  53. x_mem[i] = eT(start + i*delta);
  54. }
  55. x_mem[num_m1] = eT(end);
  56. }
  57. else
  58. {
  59. const double delta = (end >= start) ? double(end-start)/double(num_m1) : -double(start-end)/double(num_m1);
  60. for(uword i=0; i<num_m1; ++i)
  61. {
  62. x_mem[i] = eT(double(start) + i*delta);
  63. }
  64. x_mem[num_m1] = eT(end);
  65. }
  66. }
  67. return x;
  68. }
  69. arma_warn_unused
  70. inline
  71. vec
  72. linspace(const double start, const double end, const uword num = 100u)
  73. {
  74. arma_extra_debug_sigprint();
  75. return linspace<vec>(start, end, num);
  76. }
  77. template<typename out_type>
  78. arma_warn_unused
  79. inline
  80. typename
  81. enable_if2
  82. <
  83. (is_Mat<out_type>::value && is_real<typename out_type::pod_type>::value),
  84. out_type
  85. >::result
  86. logspace
  87. (
  88. const typename out_type::pod_type A,
  89. const typename out_type::pod_type B,
  90. const uword N = 50u
  91. )
  92. {
  93. arma_extra_debug_sigprint();
  94. typedef typename out_type::elem_type eT;
  95. typedef typename out_type::pod_type T;
  96. out_type x = linspace<out_type>(A,B,N);
  97. const uword n_elem = x.n_elem;
  98. eT* x_mem = x.memptr();
  99. for(uword i=0; i < n_elem; ++i)
  100. {
  101. x_mem[i] = std::pow(T(10), x_mem[i]);
  102. }
  103. return x;
  104. }
  105. arma_warn_unused
  106. inline
  107. vec
  108. logspace(const double A, const double B, const uword N = 50u)
  109. {
  110. arma_extra_debug_sigprint();
  111. return logspace<vec>(A, B, N);
  112. }
  113. //
  114. // log_exp_add
  115. template<typename eT>
  116. arma_warn_unused
  117. inline
  118. typename arma_real_only<eT>::result
  119. log_add_exp(eT log_a, eT log_b)
  120. {
  121. if(log_a < log_b)
  122. {
  123. std::swap(log_a, log_b);
  124. }
  125. const eT negdelta = log_b - log_a;
  126. if( (negdelta < Datum<eT>::log_min) || (arma_isfinite(negdelta) == false) )
  127. {
  128. return log_a;
  129. }
  130. else
  131. {
  132. return (log_a + arma_log1p(std::exp(negdelta)));
  133. }
  134. }
  135. // for compatibility with earlier versions
  136. template<typename eT>
  137. arma_warn_unused
  138. inline
  139. typename arma_real_only<eT>::result
  140. log_add(eT log_a, eT log_b)
  141. {
  142. return log_add_exp(log_a, log_b);
  143. }
  144. //! kept for compatibility with old user code
  145. template<typename eT>
  146. arma_warn_unused
  147. arma_inline
  148. bool
  149. is_finite(const eT x, const typename arma_scalar_only<eT>::result* junk = 0)
  150. {
  151. arma_ignore(junk);
  152. return arma_isfinite(x);
  153. }
  154. //! kept for compatibility with old user code
  155. template<typename T1>
  156. arma_warn_unused
  157. inline
  158. bool
  159. is_finite(const Base<typename T1::elem_type,T1>& X)
  160. {
  161. arma_extra_debug_sigprint();
  162. return X.is_finite();
  163. }
  164. //! kept for compatibility with old user code
  165. template<typename T1>
  166. arma_warn_unused
  167. inline
  168. bool
  169. is_finite(const SpBase<typename T1::elem_type,T1>& X)
  170. {
  171. arma_extra_debug_sigprint();
  172. return X.is_finite();
  173. }
  174. //! kept for compatibility with old user code
  175. template<typename T1>
  176. arma_warn_unused
  177. inline
  178. bool
  179. is_finite(const BaseCube<typename T1::elem_type,T1>& X)
  180. {
  181. arma_extra_debug_sigprint();
  182. return X.is_finite();
  183. }
  184. //! NOTE: don't use this function: it will be removed
  185. template<typename T1>
  186. arma_deprecated
  187. inline
  188. const T1&
  189. sympd(const Base<typename T1::elem_type,T1>& X)
  190. {
  191. arma_extra_debug_sigprint();
  192. arma_debug_warn("sympd() is deprecated and will be removed; change inv(sympd(X)) to inv_sympd(X)");
  193. return X.get_ref();
  194. }
  195. template<typename eT>
  196. inline
  197. void
  198. swap(Mat<eT>& A, Mat<eT>& B)
  199. {
  200. arma_extra_debug_sigprint();
  201. A.swap(B);
  202. }
  203. template<typename eT>
  204. inline
  205. void
  206. swap(Cube<eT>& A, Cube<eT>& B)
  207. {
  208. arma_extra_debug_sigprint();
  209. A.swap(B);
  210. }
  211. arma_warn_unused
  212. inline
  213. uvec
  214. ind2sub(const SizeMat& s, const uword i)
  215. {
  216. arma_extra_debug_sigprint();
  217. const uword s_n_rows = s.n_rows;
  218. arma_debug_check( (i >= (s_n_rows * s.n_cols) ), "ind2sub(): index out of range" );
  219. const uword row = i % s_n_rows;
  220. const uword col = i / s_n_rows;
  221. uvec out(2);
  222. uword* out_mem = out.memptr();
  223. out_mem[0] = row;
  224. out_mem[1] = col;
  225. return out;
  226. }
  227. template<typename T1>
  228. arma_warn_unused
  229. inline
  230. typename enable_if2< (is_arma_type<T1>::value && is_same_type<uword,typename T1::elem_type>::yes), umat >::result
  231. ind2sub(const SizeMat& s, const T1& indices)
  232. {
  233. arma_extra_debug_sigprint();
  234. const uword s_n_rows = s.n_rows;
  235. const uword s_n_elem = s_n_rows * s.n_cols;
  236. const Proxy<T1> P(indices);
  237. const uword P_n_rows = P.get_n_rows();
  238. const uword P_n_cols = P.get_n_cols();
  239. const uword P_n_elem = P.get_n_elem();
  240. const bool P_is_empty = (P_n_elem == 0);
  241. const bool P_is_vec = ((P_n_rows == 1) || (P_n_cols == 1));
  242. arma_debug_check( ((P_is_empty == false) && (P_is_vec == false)), "ind2sub(): parameter 'indices' must be a vector" );
  243. umat out(2,P_n_elem);
  244. if(Proxy<T1>::use_at == false)
  245. {
  246. typename Proxy<T1>::ea_type Pea = P.get_ea();
  247. for(uword count=0; count < P_n_elem; ++count)
  248. {
  249. const uword i = Pea[count];
  250. arma_debug_check( (i >= s_n_elem), "ind2sub(): index out of range" );
  251. const uword row = i % s_n_rows;
  252. const uword col = i / s_n_rows;
  253. uword* out_colptr = out.colptr(count);
  254. out_colptr[0] = row;
  255. out_colptr[1] = col;
  256. }
  257. }
  258. else
  259. {
  260. if(P_n_rows == 1)
  261. {
  262. for(uword count=0; count < P_n_cols; ++count)
  263. {
  264. const uword i = P.at(0,count);
  265. arma_debug_check( (i >= s_n_elem), "ind2sub(): index out of range" );
  266. const uword row = i % s_n_rows;
  267. const uword col = i / s_n_rows;
  268. uword* out_colptr = out.colptr(count);
  269. out_colptr[0] = row;
  270. out_colptr[1] = col;
  271. }
  272. }
  273. else
  274. if(P_n_cols == 1)
  275. {
  276. for(uword count=0; count < P_n_rows; ++count)
  277. {
  278. const uword i = P.at(count,0);
  279. arma_debug_check( (i >= s_n_elem), "ind2sub(): index out of range" );
  280. const uword row = i % s_n_rows;
  281. const uword col = i / s_n_rows;
  282. uword* out_colptr = out.colptr(count);
  283. out_colptr[0] = row;
  284. out_colptr[1] = col;
  285. }
  286. }
  287. }
  288. return out;
  289. }
  290. arma_warn_unused
  291. inline
  292. uvec
  293. ind2sub(const SizeCube& s, const uword i)
  294. {
  295. arma_extra_debug_sigprint();
  296. const uword s_n_rows = s.n_rows;
  297. const uword s_n_elem_slice = s_n_rows * s.n_cols;
  298. arma_debug_check( (i >= (s_n_elem_slice * s.n_slices) ), "ind2sub(): index out of range" );
  299. const uword slice = i / s_n_elem_slice;
  300. const uword j = i - (slice * s_n_elem_slice);
  301. const uword row = j % s_n_rows;
  302. const uword col = j / s_n_rows;
  303. uvec out(3);
  304. uword* out_mem = out.memptr();
  305. out_mem[0] = row;
  306. out_mem[1] = col;
  307. out_mem[2] = slice;
  308. return out;
  309. }
  310. template<typename T1>
  311. arma_warn_unused
  312. inline
  313. typename enable_if2< (is_arma_type<T1>::value && is_same_type<uword,typename T1::elem_type>::yes), umat >::result
  314. ind2sub(const SizeCube& s, const T1& indices)
  315. {
  316. arma_extra_debug_sigprint();
  317. const uword s_n_rows = s.n_rows;
  318. const uword s_n_elem_slice = s_n_rows * s.n_cols;
  319. const uword s_n_elem = s.n_slices * s_n_elem_slice;
  320. const quasi_unwrap<T1> U(indices);
  321. arma_debug_check( ((U.M.is_empty() == false) && (U.M.is_vec() == false)), "ind2sub(): parameter 'indices' must be a vector" );
  322. const uword U_n_elem = U.M.n_elem;
  323. const uword* U_mem = U.M.memptr();
  324. umat out(3,U_n_elem);
  325. for(uword count=0; count < U_n_elem; ++count)
  326. {
  327. const uword i = U_mem[count];
  328. arma_debug_check( (i >= s_n_elem), "ind2sub(): index out of range" );
  329. const uword slice = i / s_n_elem_slice;
  330. const uword j = i - (slice * s_n_elem_slice);
  331. const uword row = j % s_n_rows;
  332. const uword col = j / s_n_rows;
  333. uword* out_colptr = out.colptr(count);
  334. out_colptr[0] = row;
  335. out_colptr[1] = col;
  336. out_colptr[2] = slice;
  337. }
  338. return out;
  339. }
  340. arma_warn_unused
  341. arma_inline
  342. uword
  343. sub2ind(const SizeMat& s, const uword row, const uword col)
  344. {
  345. arma_extra_debug_sigprint();
  346. const uword s_n_rows = s.n_rows;
  347. arma_debug_check( ((row >= s_n_rows) || (col >= s.n_cols)), "sub2ind(): subscript out of range" );
  348. return uword(row + col*s_n_rows);
  349. }
  350. template<typename T1>
  351. arma_warn_unused
  352. inline
  353. uvec
  354. sub2ind(const SizeMat& s, const Base<uword,T1>& subscripts)
  355. {
  356. arma_extra_debug_sigprint();
  357. const uword s_n_rows = s.n_rows;
  358. const uword s_n_cols = s.n_cols;
  359. const quasi_unwrap<T1> U(subscripts.get_ref());
  360. arma_debug_check( (U.M.n_rows != 2), "sub2ind(): matrix of subscripts must have 2 rows" );
  361. const uword U_M_n_cols = U.M.n_cols;
  362. uvec out(U_M_n_cols);
  363. uword* out_mem = out.memptr();
  364. const uword* U_M_mem = U.M.memptr();
  365. for(uword count=0; count < U_M_n_cols; ++count)
  366. {
  367. const uword row = U_M_mem[0];
  368. const uword col = U_M_mem[1];
  369. U_M_mem += 2; // next column
  370. arma_debug_check( ((row >= s_n_rows) || (col >= s_n_cols)), "sub2ind(): subscript out of range" );
  371. out_mem[count] = uword(row + col*s_n_rows);
  372. }
  373. return out;
  374. }
  375. arma_warn_unused
  376. arma_inline
  377. uword
  378. sub2ind(const SizeCube& s, const uword row, const uword col, const uword slice)
  379. {
  380. arma_extra_debug_sigprint();
  381. const uword s_n_rows = s.n_rows;
  382. const uword s_n_cols = s.n_cols;
  383. arma_debug_check( ((row >= s_n_rows) || (col >= s_n_cols) || (slice >= s.n_slices)), "sub2ind(): subscript out of range" );
  384. return uword( (slice * s_n_rows * s_n_cols) + (col * s_n_rows) + row );
  385. }
  386. template<typename T1>
  387. arma_warn_unused
  388. inline
  389. uvec
  390. sub2ind(const SizeCube& s, const Base<uword,T1>& subscripts)
  391. {
  392. arma_extra_debug_sigprint();
  393. const uword s_n_rows = s.n_rows;
  394. const uword s_n_cols = s.n_cols;
  395. const uword s_n_slices = s.n_slices;
  396. const quasi_unwrap<T1> U(subscripts.get_ref());
  397. arma_debug_check( (U.M.n_rows != 3), "sub2ind(): matrix of subscripts must have 3 rows" );
  398. const uword U_M_n_cols = U.M.n_cols;
  399. uvec out(U_M_n_cols);
  400. uword* out_mem = out.memptr();
  401. const uword* U_M_mem = U.M.memptr();
  402. for(uword count=0; count < U_M_n_cols; ++count)
  403. {
  404. const uword row = U_M_mem[0];
  405. const uword col = U_M_mem[1];
  406. const uword slice = U_M_mem[2];
  407. U_M_mem += 3; // next column
  408. arma_debug_check( ((row >= s_n_rows) || (col >= s_n_cols) || (slice >= s_n_slices)), "sub2ind(): subscript out of range" );
  409. out_mem[count] = uword( (slice * s_n_rows * s_n_cols) + (col * s_n_rows) + row );
  410. }
  411. return out;
  412. }
  413. template<typename T1, typename T2>
  414. arma_inline
  415. typename
  416. enable_if2
  417. <
  418. (is_arma_type<T1>::value && is_same_type<typename T1::elem_type, typename T2::elem_type>::value),
  419. const Glue<T1, T2, glue_affmul>
  420. >::result
  421. affmul(const T1& A, const T2& B)
  422. {
  423. arma_extra_debug_sigprint();
  424. return Glue<T1, T2, glue_affmul>(A,B);
  425. }
  426. //! @}