op_mean_meat.hpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup op_mean
  16. //! @{
  17. template<typename T1>
  18. inline
  19. void
  20. op_mean::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_mean>& in)
  21. {
  22. arma_extra_debug_sigprint();
  23. typedef typename T1::elem_type eT;
  24. const uword dim = in.aux_uword_a;
  25. arma_debug_check( (dim > 1), "mean(): parameter 'dim' must be 0 or 1" );
  26. const Proxy<T1> P(in.m);
  27. if(P.is_alias(out) == false)
  28. {
  29. op_mean::apply_noalias(out, P, dim);
  30. }
  31. else
  32. {
  33. Mat<eT> tmp;
  34. op_mean::apply_noalias(tmp, P, dim);
  35. out.steal_mem(tmp);
  36. }
  37. }
  38. template<typename T1>
  39. inline
  40. void
  41. op_mean::apply_noalias(Mat<typename T1::elem_type>& out, const Proxy<T1>& P, const uword dim)
  42. {
  43. arma_extra_debug_sigprint();
  44. if(is_Mat<typename Proxy<T1>::stored_type>::value)
  45. {
  46. op_mean::apply_noalias_unwrap(out, P, dim);
  47. }
  48. else
  49. {
  50. op_mean::apply_noalias_proxy(out, P, dim);
  51. }
  52. }
  53. template<typename T1>
  54. inline
  55. void
  56. op_mean::apply_noalias_unwrap(Mat<typename T1::elem_type>& out, const Proxy<T1>& P, const uword dim)
  57. {
  58. arma_extra_debug_sigprint();
  59. typedef typename T1::elem_type eT;
  60. typedef typename get_pod_type<eT>::result T;
  61. typedef typename Proxy<T1>::stored_type P_stored_type;
  62. const unwrap<P_stored_type> tmp(P.Q);
  63. const typename unwrap<P_stored_type>::stored_type& X = tmp.M;
  64. const uword X_n_rows = X.n_rows;
  65. const uword X_n_cols = X.n_cols;
  66. if(dim == 0)
  67. {
  68. out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols);
  69. if(X_n_rows == 0) { return; }
  70. eT* out_mem = out.memptr();
  71. for(uword col=0; col < X_n_cols; ++col)
  72. {
  73. out_mem[col] = op_mean::direct_mean( X.colptr(col), X_n_rows );
  74. }
  75. }
  76. else
  77. if(dim == 1)
  78. {
  79. out.zeros(X_n_rows, (X_n_cols > 0) ? 1 : 0);
  80. if(X_n_cols == 0) { return; }
  81. eT* out_mem = out.memptr();
  82. for(uword col=0; col < X_n_cols; ++col)
  83. {
  84. const eT* col_mem = X.colptr(col);
  85. for(uword row=0; row < X_n_rows; ++row)
  86. {
  87. out_mem[row] += col_mem[row];
  88. }
  89. }
  90. out /= T(X_n_cols);
  91. for(uword row=0; row < X_n_rows; ++row)
  92. {
  93. if(arma_isfinite(out_mem[row]) == false)
  94. {
  95. out_mem[row] = op_mean::direct_mean_robust( X, row );
  96. }
  97. }
  98. }
  99. }
  100. template<typename T1>
  101. arma_hot
  102. inline
  103. void
  104. op_mean::apply_noalias_proxy(Mat<typename T1::elem_type>& out, const Proxy<T1>& P, const uword dim)
  105. {
  106. arma_extra_debug_sigprint();
  107. typedef typename T1::elem_type eT;
  108. typedef typename get_pod_type<eT>::result T;
  109. const uword P_n_rows = P.get_n_rows();
  110. const uword P_n_cols = P.get_n_cols();
  111. if(dim == 0)
  112. {
  113. out.set_size((P_n_rows > 0) ? 1 : 0, P_n_cols);
  114. if(P_n_rows == 0) { return; }
  115. eT* out_mem = out.memptr();
  116. for(uword col=0; col < P_n_cols; ++col)
  117. {
  118. eT val1 = eT(0);
  119. eT val2 = eT(0);
  120. uword i,j;
  121. for(i=0, j=1; j < P_n_rows; i+=2, j+=2)
  122. {
  123. val1 += P.at(i,col);
  124. val2 += P.at(j,col);
  125. }
  126. if(i < P_n_rows)
  127. {
  128. val1 += P.at(i,col);
  129. }
  130. out_mem[col] = (val1 + val2) / T(P_n_rows);
  131. }
  132. }
  133. else
  134. if(dim == 1)
  135. {
  136. out.zeros(P_n_rows, (P_n_cols > 0) ? 1 : 0);
  137. if(P_n_cols == 0) { return; }
  138. eT* out_mem = out.memptr();
  139. for(uword col=0; col < P_n_cols; ++col)
  140. for(uword row=0; row < P_n_rows; ++row)
  141. {
  142. out_mem[row] += P.at(row,col);
  143. }
  144. out /= T(P_n_cols);
  145. }
  146. if(out.is_finite() == false)
  147. {
  148. // TODO: replace with dedicated handling to avoid unwrapping
  149. op_mean::apply_noalias_unwrap(out, P, dim);
  150. }
  151. }
  152. //
  153. // cubes
  154. template<typename T1>
  155. inline
  156. void
  157. op_mean::apply(Cube<typename T1::elem_type>& out, const OpCube<T1,op_mean>& in)
  158. {
  159. arma_extra_debug_sigprint();
  160. typedef typename T1::elem_type eT;
  161. const uword dim = in.aux_uword_a;
  162. arma_debug_check( (dim > 2), "mean(): parameter 'dim' must be 0 or 1 or 2" );
  163. const ProxyCube<T1> P(in.m);
  164. if(P.is_alias(out) == false)
  165. {
  166. op_mean::apply_noalias(out, P, dim);
  167. }
  168. else
  169. {
  170. Cube<eT> tmp;
  171. op_mean::apply_noalias(tmp, P, dim);
  172. out.steal_mem(tmp);
  173. }
  174. }
  175. template<typename T1>
  176. inline
  177. void
  178. op_mean::apply_noalias(Cube<typename T1::elem_type>& out, const ProxyCube<T1>& P, const uword dim)
  179. {
  180. arma_extra_debug_sigprint();
  181. if(is_Cube<typename ProxyCube<T1>::stored_type>::value)
  182. {
  183. op_mean::apply_noalias_unwrap(out, P, dim);
  184. }
  185. else
  186. {
  187. op_mean::apply_noalias_proxy(out, P, dim);
  188. }
  189. }
  190. template<typename T1>
  191. inline
  192. void
  193. op_mean::apply_noalias_unwrap(Cube<typename T1::elem_type>& out, const ProxyCube<T1>& P, const uword dim)
  194. {
  195. arma_extra_debug_sigprint();
  196. typedef typename T1::elem_type eT;
  197. typedef typename get_pod_type<eT>::result T;
  198. typedef typename ProxyCube<T1>::stored_type P_stored_type;
  199. const unwrap_cube<P_stored_type> U(P.Q);
  200. const Cube<eT>& X = U.M;
  201. const uword X_n_rows = X.n_rows;
  202. const uword X_n_cols = X.n_cols;
  203. const uword X_n_slices = X.n_slices;
  204. if(dim == 0)
  205. {
  206. out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols, X_n_slices);
  207. if(X_n_rows == 0) { return; }
  208. for(uword slice=0; slice < X_n_slices; ++slice)
  209. {
  210. eT* out_mem = out.slice_memptr(slice);
  211. for(uword col=0; col < X_n_cols; ++col)
  212. {
  213. out_mem[col] = op_mean::direct_mean( X.slice_colptr(slice,col), X_n_rows );
  214. }
  215. }
  216. }
  217. else
  218. if(dim == 1)
  219. {
  220. out.zeros(X_n_rows, (X_n_cols > 0) ? 1 : 0, X_n_slices);
  221. if(X_n_cols == 0) { return; }
  222. for(uword slice=0; slice < X_n_slices; ++slice)
  223. {
  224. eT* out_mem = out.slice_memptr(slice);
  225. for(uword col=0; col < X_n_cols; ++col)
  226. {
  227. const eT* col_mem = X.slice_colptr(slice,col);
  228. for(uword row=0; row < X_n_rows; ++row)
  229. {
  230. out_mem[row] += col_mem[row];
  231. }
  232. }
  233. const Mat<eT> tmp('j', X.slice_memptr(slice), X_n_rows, X_n_cols);
  234. for(uword row=0; row < X_n_rows; ++row)
  235. {
  236. out_mem[row] /= T(X_n_cols);
  237. if(arma_isfinite(out_mem[row]) == false)
  238. {
  239. out_mem[row] = op_mean::direct_mean_robust( tmp, row );
  240. }
  241. }
  242. }
  243. }
  244. else
  245. if(dim == 2)
  246. {
  247. out.zeros(X_n_rows, X_n_cols, (X_n_slices > 0) ? 1 : 0);
  248. if(X_n_slices == 0) { return; }
  249. eT* out_mem = out.memptr();
  250. for(uword slice=0; slice < X_n_slices; ++slice)
  251. {
  252. arrayops::inplace_plus(out_mem, X.slice_memptr(slice), X.n_elem_slice );
  253. }
  254. out /= T(X_n_slices);
  255. podarray<eT> tmp(X_n_slices);
  256. for(uword col=0; col < X_n_cols; ++col)
  257. for(uword row=0; row < X_n_rows; ++row)
  258. {
  259. if(arma_isfinite(out.at(row,col,0)) == false)
  260. {
  261. for(uword slice=0; slice < X_n_slices; ++slice)
  262. {
  263. tmp[slice] = X.at(row,col,slice);
  264. }
  265. out.at(row,col,0) = op_mean::direct_mean_robust(tmp.memptr(), X_n_slices);
  266. }
  267. }
  268. }
  269. }
  270. template<typename T1>
  271. arma_hot
  272. inline
  273. void
  274. op_mean::apply_noalias_proxy(Cube<typename T1::elem_type>& out, const ProxyCube<T1>& P, const uword dim)
  275. {
  276. arma_extra_debug_sigprint();
  277. op_mean::apply_noalias_unwrap(out, P, dim);
  278. // TODO: implement specialised handling
  279. }
  280. //
  281. template<typename eT>
  282. inline
  283. eT
  284. op_mean::direct_mean(const eT* const X, const uword n_elem)
  285. {
  286. arma_extra_debug_sigprint();
  287. typedef typename get_pod_type<eT>::result T;
  288. const eT result = arrayops::accumulate(X, n_elem) / T(n_elem);
  289. return arma_isfinite(result) ? result : op_mean::direct_mean_robust(X, n_elem);
  290. }
  291. template<typename eT>
  292. inline
  293. eT
  294. op_mean::direct_mean_robust(const eT* const X, const uword n_elem)
  295. {
  296. arma_extra_debug_sigprint();
  297. // use an adapted form of the mean finding algorithm from the running_stat class
  298. typedef typename get_pod_type<eT>::result T;
  299. uword i,j;
  300. eT r_mean = eT(0);
  301. for(i=0, j=1; j<n_elem; i+=2, j+=2)
  302. {
  303. const eT Xi = X[i];
  304. const eT Xj = X[j];
  305. r_mean = r_mean + (Xi - r_mean)/T(j); // we need i+1, and j is equivalent to i+1 here
  306. r_mean = r_mean + (Xj - r_mean)/T(j+1);
  307. }
  308. if(i < n_elem)
  309. {
  310. const eT Xi = X[i];
  311. r_mean = r_mean + (Xi - r_mean)/T(i+1);
  312. }
  313. return r_mean;
  314. }
  315. template<typename eT>
  316. inline
  317. eT
  318. op_mean::direct_mean(const Mat<eT>& X, const uword row)
  319. {
  320. arma_extra_debug_sigprint();
  321. typedef typename get_pod_type<eT>::result T;
  322. const uword X_n_cols = X.n_cols;
  323. eT val = eT(0);
  324. uword i,j;
  325. for(i=0, j=1; j < X_n_cols; i+=2, j+=2)
  326. {
  327. val += X.at(row,i);
  328. val += X.at(row,j);
  329. }
  330. if(i < X_n_cols)
  331. {
  332. val += X.at(row,i);
  333. }
  334. const eT result = val / T(X_n_cols);
  335. return arma_isfinite(result) ? result : op_mean::direct_mean_robust(X, row);
  336. }
  337. template<typename eT>
  338. inline
  339. eT
  340. op_mean::direct_mean_robust(const Mat<eT>& X, const uword row)
  341. {
  342. arma_extra_debug_sigprint();
  343. typedef typename get_pod_type<eT>::result T;
  344. const uword X_n_cols = X.n_cols;
  345. eT r_mean = eT(0);
  346. for(uword col=0; col < X_n_cols; ++col)
  347. {
  348. r_mean = r_mean + (X.at(row,col) - r_mean)/T(col+1);
  349. }
  350. return r_mean;
  351. }
  352. template<typename eT>
  353. inline
  354. eT
  355. op_mean::mean_all(const subview<eT>& X)
  356. {
  357. arma_extra_debug_sigprint();
  358. typedef typename get_pod_type<eT>::result T;
  359. const uword X_n_rows = X.n_rows;
  360. const uword X_n_cols = X.n_cols;
  361. const uword X_n_elem = X.n_elem;
  362. if(X_n_elem == 0)
  363. {
  364. arma_debug_check(true, "mean(): object has no elements");
  365. return Datum<eT>::nan;
  366. }
  367. eT val = eT(0);
  368. if(X_n_rows == 1)
  369. {
  370. const Mat<eT>& A = X.m;
  371. const uword start_row = X.aux_row1;
  372. const uword start_col = X.aux_col1;
  373. const uword end_col_p1 = start_col + X_n_cols;
  374. uword i,j;
  375. for(i=start_col, j=start_col+1; j < end_col_p1; i+=2, j+=2)
  376. {
  377. val += A.at(start_row, i);
  378. val += A.at(start_row, j);
  379. }
  380. if(i < end_col_p1)
  381. {
  382. val += A.at(start_row, i);
  383. }
  384. }
  385. else
  386. {
  387. for(uword col=0; col < X_n_cols; ++col)
  388. {
  389. val += arrayops::accumulate(X.colptr(col), X_n_rows);
  390. }
  391. }
  392. const eT result = val / T(X_n_elem);
  393. return arma_isfinite(result) ? result : op_mean::mean_all_robust(X);
  394. }
  395. template<typename eT>
  396. inline
  397. eT
  398. op_mean::mean_all_robust(const subview<eT>& X)
  399. {
  400. arma_extra_debug_sigprint();
  401. typedef typename get_pod_type<eT>::result T;
  402. const uword X_n_rows = X.n_rows;
  403. const uword X_n_cols = X.n_cols;
  404. const uword start_row = X.aux_row1;
  405. const uword start_col = X.aux_col1;
  406. const uword end_row_p1 = start_row + X_n_rows;
  407. const uword end_col_p1 = start_col + X_n_cols;
  408. const Mat<eT>& A = X.m;
  409. eT r_mean = eT(0);
  410. if(X_n_rows == 1)
  411. {
  412. uword i=0;
  413. for(uword col = start_col; col < end_col_p1; ++col, ++i)
  414. {
  415. r_mean = r_mean + (A.at(start_row,col) - r_mean)/T(i+1);
  416. }
  417. }
  418. else
  419. {
  420. uword i=0;
  421. for(uword col = start_col; col < end_col_p1; ++col)
  422. for(uword row = start_row; row < end_row_p1; ++row, ++i)
  423. {
  424. r_mean = r_mean + (A.at(row,col) - r_mean)/T(i+1);
  425. }
  426. }
  427. return r_mean;
  428. }
  429. template<typename eT>
  430. inline
  431. eT
  432. op_mean::mean_all(const diagview<eT>& X)
  433. {
  434. arma_extra_debug_sigprint();
  435. typedef typename get_pod_type<eT>::result T;
  436. const uword X_n_elem = X.n_elem;
  437. if(X_n_elem == 0)
  438. {
  439. arma_debug_check(true, "mean(): object has no elements");
  440. return Datum<eT>::nan;
  441. }
  442. eT val = eT(0);
  443. for(uword i=0; i<X_n_elem; ++i)
  444. {
  445. val += X[i];
  446. }
  447. const eT result = val / T(X_n_elem);
  448. return arma_isfinite(result) ? result : op_mean::mean_all_robust(X);
  449. }
  450. template<typename eT>
  451. inline
  452. eT
  453. op_mean::mean_all_robust(const diagview<eT>& X)
  454. {
  455. arma_extra_debug_sigprint();
  456. typedef typename get_pod_type<eT>::result T;
  457. const uword X_n_elem = X.n_elem;
  458. eT r_mean = eT(0);
  459. for(uword i=0; i<X_n_elem; ++i)
  460. {
  461. r_mean = r_mean + (X[i] - r_mean)/T(i+1);
  462. }
  463. return r_mean;
  464. }
  465. template<typename T1>
  466. inline
  467. typename T1::elem_type
  468. op_mean::mean_all(const Op<T1,op_vectorise_col>& X)
  469. {
  470. arma_extra_debug_sigprint();
  471. return op_mean::mean_all(X.m);
  472. }
  473. template<typename T1>
  474. inline
  475. typename T1::elem_type
  476. op_mean::mean_all(const Base<typename T1::elem_type, T1>& X)
  477. {
  478. arma_extra_debug_sigprint();
  479. typedef typename T1::elem_type eT;
  480. const quasi_unwrap<T1> tmp(X.get_ref());
  481. const Mat<eT>& A = tmp.M;
  482. const uword A_n_elem = A.n_elem;
  483. if(A_n_elem == 0)
  484. {
  485. arma_debug_check(true, "mean(): object has no elements");
  486. return Datum<eT>::nan;
  487. }
  488. return op_mean::direct_mean(A.memptr(), A_n_elem);
  489. }
  490. template<typename eT>
  491. arma_inline
  492. eT
  493. op_mean::robust_mean(const eT A, const eT B)
  494. {
  495. return A + (B - A)/eT(2);
  496. }
  497. template<typename T>
  498. arma_inline
  499. std::complex<T>
  500. op_mean::robust_mean(const std::complex<T>& A, const std::complex<T>& B)
  501. {
  502. return A + (B - A)/T(2);
  503. }
  504. //! @}