  // op_sum_meat.hpp

  // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  // Copyright 2008-2016 National ICT Australia (NICTA)
  //
  // Licensed under the Apache License, Version 2.0 (the "License");
  // you may not use this file except in compliance with the License.
  // You may obtain a copy of the License at
  // http://www.apache.org/licenses/LICENSE-2.0
  //
  // Unless required by applicable law or agreed to in writing, software
  // distributed under the License is distributed on an "AS IS" BASIS,
  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  // See the License for the specific language governing permissions and
  // limitations under the License.
  // ------------------------------------------------------------------------
  //! \addtogroup op_sum
  //! @{
  17. template<typename T1>
  18. arma_hot
  19. inline
  20. void
  21. op_sum::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_sum>& in)
  22. {
  23. arma_extra_debug_sigprint();
  24. typedef typename T1::elem_type eT;
  25. const uword dim = in.aux_uword_a;
  26. arma_debug_check( (dim > 1), "sum(): parameter 'dim' must be 0 or 1" );
  27. const Proxy<T1> P(in.m);
  28. if(P.is_alias(out) == false)
  29. {
  30. op_sum::apply_noalias(out, P, dim);
  31. }
  32. else
  33. {
  34. Mat<eT> tmp;
  35. op_sum::apply_noalias(tmp, P, dim);
  36. out.steal_mem(tmp);
  37. }
  38. }
  39. template<typename T1>
  40. arma_hot
  41. inline
  42. void
  43. op_sum::apply_noalias(Mat<typename T1::elem_type>& out, const Proxy<T1>& P, const uword dim)
  44. {
  45. arma_extra_debug_sigprint();
  46. if(is_Mat<typename Proxy<T1>::stored_type>::value)
  47. {
  48. op_sum::apply_noalias_unwrap(out, P, dim);
  49. }
  50. else
  51. {
  52. op_sum::apply_noalias_proxy(out, P, dim);
  53. }
  54. }
  55. template<typename T1>
  56. arma_hot
  57. inline
  58. void
  59. op_sum::apply_noalias_unwrap(Mat<typename T1::elem_type>& out, const Proxy<T1>& P, const uword dim)
  60. {
  61. arma_extra_debug_sigprint();
  62. typedef typename T1::elem_type eT;
  63. typedef typename Proxy<T1>::stored_type P_stored_type;
  64. const unwrap<P_stored_type> tmp(P.Q);
  65. const typename unwrap<P_stored_type>::stored_type& X = tmp.M;
  66. const uword X_n_rows = X.n_rows;
  67. const uword X_n_cols = X.n_cols;
  68. if(dim == 0)
  69. {
  70. out.set_size(1, X_n_cols);
  71. eT* out_mem = out.memptr();
  72. for(uword col=0; col < X_n_cols; ++col)
  73. {
  74. out_mem[col] = arrayops::accumulate( X.colptr(col), X_n_rows );
  75. }
  76. }
  77. else
  78. {
  79. out.zeros(X_n_rows, 1);
  80. eT* out_mem = out.memptr();
  81. for(uword col=0; col < X_n_cols; ++col)
  82. {
  83. arrayops::inplace_plus( out_mem, X.colptr(col), X_n_rows );
  84. }
  85. }
  86. }
  87. template<typename T1>
  88. arma_hot
  89. inline
  90. void
  91. op_sum::apply_noalias_proxy(Mat<typename T1::elem_type>& out, const Proxy<T1>& P, const uword dim)
  92. {
  93. arma_extra_debug_sigprint();
  94. typedef typename T1::elem_type eT;
  95. if( arma_config::openmp && Proxy<T1>::use_mp && mp_gate<eT>::eval(P.get_n_elem()) )
  96. {
  97. op_sum::apply_noalias_proxy_mp(out, P, dim);
  98. return;
  99. }
  100. const uword P_n_rows = P.get_n_rows();
  101. const uword P_n_cols = P.get_n_cols();
  102. if(dim == 0)
  103. {
  104. out.set_size(1, P_n_cols);
  105. eT* out_mem = out.memptr();
  106. for(uword col=0; col < P_n_cols; ++col)
  107. {
  108. eT val1 = eT(0);
  109. eT val2 = eT(0);
  110. uword i,j;
  111. for(i=0, j=1; j < P_n_rows; i+=2, j+=2)
  112. {
  113. val1 += P.at(i,col);
  114. val2 += P.at(j,col);
  115. }
  116. if(i < P_n_rows)
  117. {
  118. val1 += P.at(i,col);
  119. }
  120. out_mem[col] = (val1 + val2);
  121. }
  122. }
  123. else
  124. {
  125. out.zeros(P_n_rows, 1);
  126. eT* out_mem = out.memptr();
  127. for(uword col=0; col < P_n_cols; ++col)
  128. for(uword row=0; row < P_n_rows; ++row)
  129. {
  130. out_mem[row] += P.at(row,col);
  131. }
  132. }
  133. }
  134. template<typename T1>
  135. arma_hot
  136. inline
  137. void
  138. op_sum::apply_noalias_proxy_mp(Mat<typename T1::elem_type>& out, const Proxy<T1>& P, const uword dim)
  139. {
  140. arma_extra_debug_sigprint();
  141. #if defined(ARMA_USE_OPENMP)
  142. {
  143. typedef typename T1::elem_type eT;
  144. const uword P_n_rows = P.get_n_rows();
  145. const uword P_n_cols = P.get_n_cols();
  146. const int n_threads = mp_thread_limit::get();
  147. if(dim == 0)
  148. {
  149. out.set_size(1, P_n_cols);
  150. eT* out_mem = out.memptr();
  151. #pragma omp parallel for schedule(static) num_threads(n_threads)
  152. for(uword col=0; col < P_n_cols; ++col)
  153. {
  154. eT val1 = eT(0);
  155. eT val2 = eT(0);
  156. uword i,j;
  157. for(i=0, j=1; j < P_n_rows; i+=2, j+=2)
  158. {
  159. val1 += P.at(i,col);
  160. val2 += P.at(j,col);
  161. }
  162. if(i < P_n_rows)
  163. {
  164. val1 += P.at(i,col);
  165. }
  166. out_mem[col] = (val1 + val2);
  167. }
  168. }
  169. else
  170. {
  171. out.set_size(P_n_rows, 1);
  172. eT* out_mem = out.memptr();
  173. #pragma omp parallel for schedule(static) num_threads(n_threads)
  174. for(uword row=0; row < P_n_rows; ++row)
  175. {
  176. eT acc = eT(0);
  177. for(uword col=0; col < P_n_cols; ++col)
  178. {
  179. acc += P.at(row,col);
  180. }
  181. out_mem[row] = acc;
  182. }
  183. }
  184. }
  185. #else
  186. {
  187. arma_ignore(out);
  188. arma_ignore(P);
  189. arma_ignore(dim);
  190. }
  191. #endif
  192. }
  //
  // cubes
  195. template<typename T1>
  196. arma_hot
  197. inline
  198. void
  199. op_sum::apply(Cube<typename T1::elem_type>& out, const OpCube<T1,op_sum>& in)
  200. {
  201. arma_extra_debug_sigprint();
  202. typedef typename T1::elem_type eT;
  203. const uword dim = in.aux_uword_a;
  204. arma_debug_check( (dim > 2), "sum(): parameter 'dim' must be 0 or 1 or 2" );
  205. const ProxyCube<T1> P(in.m);
  206. if(P.is_alias(out) == false)
  207. {
  208. op_sum::apply_noalias(out, P, dim);
  209. }
  210. else
  211. {
  212. Cube<eT> tmp;
  213. op_sum::apply_noalias(tmp, P, dim);
  214. out.steal_mem(tmp);
  215. }
  216. }
  217. template<typename T1>
  218. arma_hot
  219. inline
  220. void
  221. op_sum::apply_noalias(Cube<typename T1::elem_type>& out, const ProxyCube<T1>& P, const uword dim)
  222. {
  223. arma_extra_debug_sigprint();
  224. if(is_Cube<typename ProxyCube<T1>::stored_type>::value)
  225. {
  226. op_sum::apply_noalias_unwrap(out, P, dim);
  227. }
  228. else
  229. {
  230. op_sum::apply_noalias_proxy(out, P, dim);
  231. }
  232. }
  233. template<typename T1>
  234. arma_hot
  235. inline
  236. void
  237. op_sum::apply_noalias_unwrap(Cube<typename T1::elem_type>& out, const ProxyCube<T1>& P, const uword dim)
  238. {
  239. arma_extra_debug_sigprint();
  240. typedef typename T1::elem_type eT;
  241. typedef typename ProxyCube<T1>::stored_type P_stored_type;
  242. const unwrap_cube<P_stored_type> tmp(P.Q);
  243. const Cube<eT>& X = tmp.M;
  244. const uword X_n_rows = X.n_rows;
  245. const uword X_n_cols = X.n_cols;
  246. const uword X_n_slices = X.n_slices;
  247. if(dim == 0)
  248. {
  249. out.set_size(1, X_n_cols, X_n_slices);
  250. for(uword slice=0; slice < X_n_slices; ++slice)
  251. {
  252. eT* out_mem = out.slice_memptr(slice);
  253. for(uword col=0; col < X_n_cols; ++col)
  254. {
  255. out_mem[col] = arrayops::accumulate( X.slice_colptr(slice,col), X_n_rows );
  256. }
  257. }
  258. }
  259. else
  260. if(dim == 1)
  261. {
  262. out.zeros(X_n_rows, 1, X_n_slices);
  263. for(uword slice=0; slice < X_n_slices; ++slice)
  264. {
  265. eT* out_mem = out.slice_memptr(slice);
  266. for(uword col=0; col < X_n_cols; ++col)
  267. {
  268. arrayops::inplace_plus( out_mem, X.slice_colptr(slice,col), X_n_rows );
  269. }
  270. }
  271. }
  272. else
  273. if(dim == 2)
  274. {
  275. out.zeros(X_n_rows, X_n_cols, 1);
  276. eT* out_mem = out.memptr();
  277. for(uword slice=0; slice < X_n_slices; ++slice)
  278. {
  279. arrayops::inplace_plus(out_mem, X.slice_memptr(slice), X.n_elem_slice );
  280. }
  281. }
  282. }
  283. template<typename T1>
  284. arma_hot
  285. inline
  286. void
  287. op_sum::apply_noalias_proxy(Cube<typename T1::elem_type>& out, const ProxyCube<T1>& P, const uword dim)
  288. {
  289. arma_extra_debug_sigprint();
  290. typedef typename T1::elem_type eT;
  291. if( arma_config::openmp && ProxyCube<T1>::use_mp && mp_gate<eT>::eval(P.get_n_elem()) )
  292. {
  293. op_sum::apply_noalias_proxy_mp(out, P, dim);
  294. return;
  295. }
  296. const uword P_n_rows = P.get_n_rows();
  297. const uword P_n_cols = P.get_n_cols();
  298. const uword P_n_slices = P.get_n_slices();
  299. if(dim == 0)
  300. {
  301. out.set_size(1, P_n_cols, P_n_slices);
  302. for(uword slice=0; slice < P_n_slices; ++slice)
  303. {
  304. eT* out_mem = out.slice_memptr(slice);
  305. for(uword col=0; col < P_n_cols; ++col)
  306. {
  307. eT val1 = eT(0);
  308. eT val2 = eT(0);
  309. uword i,j;
  310. for(i=0, j=1; j < P_n_rows; i+=2, j+=2)
  311. {
  312. val1 += P.at(i,col,slice);
  313. val2 += P.at(j,col,slice);
  314. }
  315. if(i < P_n_rows)
  316. {
  317. val1 += P.at(i,col,slice);
  318. }
  319. out_mem[col] = (val1 + val2);
  320. }
  321. }
  322. }
  323. else
  324. if(dim == 1)
  325. {
  326. out.zeros(P_n_rows, 1, P_n_slices);
  327. for(uword slice=0; slice < P_n_slices; ++slice)
  328. {
  329. eT* out_mem = out.slice_memptr(slice);
  330. for(uword col=0; col < P_n_cols; ++col)
  331. for(uword row=0; row < P_n_rows; ++row)
  332. {
  333. out_mem[row] += P.at(row,col,slice);
  334. }
  335. }
  336. }
  337. else
  338. if(dim == 2)
  339. {
  340. out.zeros(P_n_rows, P_n_cols, 1);
  341. for(uword slice=0; slice < P_n_slices; ++slice)
  342. {
  343. for(uword col=0; col < P_n_cols; ++col)
  344. {
  345. eT* out_mem = out.slice_colptr(0,col);
  346. for(uword row=0; row < P_n_rows; ++row)
  347. {
  348. out_mem[row] += P.at(row,col,slice);
  349. }
  350. }
  351. }
  352. }
  353. }
  354. template<typename T1>
  355. arma_hot
  356. inline
  357. void
  358. op_sum::apply_noalias_proxy_mp(Cube<typename T1::elem_type>& out, const ProxyCube<T1>& P, const uword dim)
  359. {
  360. arma_extra_debug_sigprint();
  361. #if defined(ARMA_USE_OPENMP)
  362. {
  363. typedef typename T1::elem_type eT;
  364. const uword P_n_rows = P.get_n_rows();
  365. const uword P_n_cols = P.get_n_cols();
  366. const uword P_n_slices = P.get_n_slices();
  367. const int n_threads = mp_thread_limit::get();
  368. if(dim == 0)
  369. {
  370. out.set_size(1, P_n_cols, P_n_slices);
  371. #pragma omp parallel for schedule(static) num_threads(n_threads)
  372. for(uword slice=0; slice < P_n_slices; ++slice)
  373. {
  374. eT* out_mem = out.slice_memptr(slice);
  375. for(uword col=0; col < P_n_cols; ++col)
  376. {
  377. eT val1 = eT(0);
  378. eT val2 = eT(0);
  379. uword i,j;
  380. for(i=0, j=1; j < P_n_rows; i+=2, j+=2)
  381. {
  382. val1 += P.at(i,col,slice);
  383. val2 += P.at(j,col,slice);
  384. }
  385. if(i < P_n_rows)
  386. {
  387. val1 += P.at(i,col,slice);
  388. }
  389. out_mem[col] = (val1 + val2);
  390. }
  391. }
  392. }
  393. else
  394. if(dim == 1)
  395. {
  396. out.zeros(P_n_rows, 1, P_n_slices);
  397. #pragma omp parallel for schedule(static) num_threads(n_threads)
  398. for(uword slice=0; slice < P_n_slices; ++slice)
  399. {
  400. eT* out_mem = out.slice_memptr(slice);
  401. for(uword col=0; col < P_n_cols; ++col)
  402. for(uword row=0; row < P_n_rows; ++row)
  403. {
  404. out_mem[row] += P.at(row,col,slice);
  405. }
  406. }
  407. }
  408. else
  409. if(dim == 2)
  410. {
  411. out.zeros(P_n_rows, P_n_cols, 1);
  412. if(P_n_cols >= P_n_rows)
  413. {
  414. #pragma omp parallel for schedule(static) num_threads(n_threads)
  415. for(uword col=0; col < P_n_cols; ++col)
  416. {
  417. for(uword row=0; row < P_n_rows; ++row)
  418. {
  419. eT acc = eT(0);
  420. for(uword slice=0; slice < P_n_slices; ++slice)
  421. {
  422. acc += P.at(row,col,slice);
  423. }
  424. out.at(row,col,0) = acc;
  425. }
  426. }
  427. }
  428. else
  429. {
  430. #pragma omp parallel for schedule(static) num_threads(n_threads)
  431. for(uword row=0; row < P_n_rows; ++row)
  432. {
  433. for(uword col=0; col < P_n_cols; ++col)
  434. {
  435. eT acc = eT(0);
  436. for(uword slice=0; slice < P_n_slices; ++slice)
  437. {
  438. acc += P.at(row,col,slice);
  439. }
  440. out.at(row,col,0) = acc;
  441. }
  442. }
  443. }
  444. }
  445. }
  446. #else
  447. {
  448. arma_ignore(out);
  449. arma_ignore(P);
  450. arma_ignore(dim);
  451. }
  452. #endif
  453. }
  //! @}