fn_accu.hpp 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup fn_accu
  16. //! @{
  17. template<typename T1>
  18. arma_hot
  19. inline
  20. typename T1::elem_type
  21. accu_proxy_linear(const Proxy<T1>& P)
  22. {
  23. arma_extra_debug_sigprint();
  24. typedef typename T1::elem_type eT;
  25. eT val = eT(0);
  26. typename Proxy<T1>::ea_type Pea = P.get_ea();
  27. const uword n_elem = P.get_n_elem();
  28. if( arma_config::openmp && Proxy<T1>::use_mp && mp_gate<eT>::eval(n_elem) )
  29. {
  30. #if defined(ARMA_USE_OPENMP)
  31. {
  32. // NOTE: using parallelisation with manual reduction workaround to take into account complex numbers;
  33. // NOTE: OpenMP versions lower than 4.0 do not support user-defined reduction
  34. const int n_threads_max = mp_thread_limit::get();
  35. const uword n_threads_use = (std::min)(uword(podarray_prealloc_n_elem::val), uword(n_threads_max));
  36. const uword chunk_size = n_elem / n_threads_use;
  37. podarray<eT> partial_accs(n_threads_use);
  38. #pragma omp parallel for schedule(static) num_threads(int(n_threads_use))
  39. for(uword thread_id=0; thread_id < n_threads_use; ++thread_id)
  40. {
  41. const uword start = (thread_id+0) * chunk_size;
  42. const uword endp1 = (thread_id+1) * chunk_size;
  43. eT acc = eT(0);
  44. for(uword i=start; i < endp1; ++i) { acc += Pea[i]; }
  45. partial_accs[thread_id] = acc;
  46. }
  47. for(uword thread_id=0; thread_id < n_threads_use; ++thread_id) { val += partial_accs[thread_id]; }
  48. for(uword i=(n_threads_use*chunk_size); i < n_elem; ++i) { val += Pea[i]; }
  49. }
  50. #endif
  51. }
  52. else
  53. {
  54. #if defined(__FINITE_MATH_ONLY__) && (__FINITE_MATH_ONLY__ > 0)
  55. {
  56. if(P.is_aligned())
  57. {
  58. typename Proxy<T1>::aligned_ea_type Pea_aligned = P.get_aligned_ea();
  59. for(uword i=0; i<n_elem; ++i) { val += Pea_aligned.at_alt(i); }
  60. }
  61. else
  62. {
  63. for(uword i=0; i<n_elem; ++i) { val += Pea[i]; }
  64. }
  65. }
  66. #else
  67. {
  68. eT val1 = eT(0);
  69. eT val2 = eT(0);
  70. uword i,j;
  71. for(i=0, j=1; j < n_elem; i+=2, j+=2) { val1 += Pea[i]; val2 += Pea[j]; }
  72. if(i < n_elem) { val1 += Pea[i]; }
  73. val = val1 + val2;
  74. }
  75. #endif
  76. }
  77. return val;
  78. }
  79. template<typename T1>
  80. arma_hot
  81. inline
  82. typename T1::elem_type
  83. accu_proxy_at_mp(const Proxy<T1>& P)
  84. {
  85. arma_extra_debug_sigprint();
  86. typedef typename T1::elem_type eT;
  87. eT val = eT(0);
  88. #if defined(ARMA_USE_OPENMP)
  89. {
  90. const uword n_rows = P.get_n_rows();
  91. const uword n_cols = P.get_n_cols();
  92. if(n_cols == 1)
  93. {
  94. const int n_threads_max = mp_thread_limit::get();
  95. const uword n_threads_use = (std::min)(uword(podarray_prealloc_n_elem::val), uword(n_threads_max));
  96. const uword chunk_size = n_rows / n_threads_use;
  97. podarray<eT> partial_accs(n_threads_use);
  98. #pragma omp parallel for schedule(static) num_threads(int(n_threads_use))
  99. for(uword thread_id=0; thread_id < n_threads_use; ++thread_id)
  100. {
  101. const uword start = (thread_id+0) * chunk_size;
  102. const uword endp1 = (thread_id+1) * chunk_size;
  103. eT acc = eT(0);
  104. for(uword i=start; i < endp1; ++i) { acc += P.at(i,0); }
  105. partial_accs[thread_id] = acc;
  106. }
  107. for(uword thread_id=0; thread_id < n_threads_use; ++thread_id) { val += partial_accs[thread_id]; }
  108. for(uword i=(n_threads_use*chunk_size); i < n_rows; ++i) { val += P.at(i,0); }
  109. }
  110. else
  111. if(n_rows == 1)
  112. {
  113. const int n_threads_max = mp_thread_limit::get();
  114. const uword n_threads_use = (std::min)(uword(podarray_prealloc_n_elem::val), uword(n_threads_max));
  115. const uword chunk_size = n_cols / n_threads_use;
  116. podarray<eT> partial_accs(n_threads_use);
  117. #pragma omp parallel for schedule(static) num_threads(int(n_threads_use))
  118. for(uword thread_id=0; thread_id < n_threads_use; ++thread_id)
  119. {
  120. const uword start = (thread_id+0) * chunk_size;
  121. const uword endp1 = (thread_id+1) * chunk_size;
  122. eT acc = eT(0);
  123. for(uword i=start; i < endp1; ++i) { acc += P.at(0,i); }
  124. partial_accs[thread_id] = acc;
  125. }
  126. for(uword thread_id=0; thread_id < n_threads_use; ++thread_id) { val += partial_accs[thread_id]; }
  127. for(uword i=(n_threads_use*chunk_size); i < n_cols; ++i) { val += P.at(0,i); }
  128. }
  129. else
  130. {
  131. podarray<eT> col_accs(n_cols);
  132. const int n_threads = mp_thread_limit::get();
  133. #pragma omp parallel for schedule(static) num_threads(n_threads)
  134. for(uword col=0; col < n_cols; ++col)
  135. {
  136. eT val1 = eT(0);
  137. eT val2 = eT(0);
  138. uword i,j;
  139. for(i=0, j=1; j < n_rows; i+=2, j+=2) { val1 += P.at(i,col); val2 += P.at(j,col); }
  140. if(i < n_rows) { val1 += P.at(i,col); }
  141. col_accs[col] = val1 + val2;
  142. }
  143. val = arrayops::accumulate(col_accs.memptr(), n_cols);
  144. }
  145. }
  146. #else
  147. {
  148. arma_ignore(P);
  149. }
  150. #endif
  151. return val;
  152. }
  153. template<typename T1>
  154. arma_hot
  155. inline
  156. typename T1::elem_type
  157. accu_proxy_at(const Proxy<T1>& P)
  158. {
  159. arma_extra_debug_sigprint();
  160. typedef typename T1::elem_type eT;
  161. if(arma_config::openmp && Proxy<T1>::use_mp && mp_gate<eT>::eval(P.get_n_elem()))
  162. {
  163. return accu_proxy_at_mp(P);
  164. }
  165. const uword n_rows = P.get_n_rows();
  166. const uword n_cols = P.get_n_cols();
  167. eT val = eT(0);
  168. if(n_rows != 1)
  169. {
  170. eT val1 = eT(0);
  171. eT val2 = eT(0);
  172. for(uword col=0; col < n_cols; ++col)
  173. {
  174. uword i,j;
  175. for(i=0, j=1; j < n_rows; i+=2, j+=2) { val1 += P.at(i,col); val2 += P.at(j,col); }
  176. if(i < n_rows) { val1 += P.at(i,col); }
  177. }
  178. val = val1 + val2;
  179. }
  180. else
  181. {
  182. for(uword col=0; col < n_cols; ++col) { val += P.at(0,col); }
  183. }
  184. return val;
  185. }
  186. //! accumulate the elements of a matrix
  187. template<typename T1>
  188. arma_warn_unused
  189. arma_hot
  190. inline
  191. typename enable_if2< is_arma_type<T1>::value, typename T1::elem_type >::result
  192. accu(const T1& X)
  193. {
  194. arma_extra_debug_sigprint();
  195. const Proxy<T1> P(X);
  196. if(is_Mat<typename Proxy<T1>::stored_type>::value || is_subview_col<typename Proxy<T1>::stored_type>::value)
  197. {
  198. const quasi_unwrap<typename Proxy<T1>::stored_type> tmp(P.Q);
  199. return arrayops::accumulate(tmp.M.memptr(), tmp.M.n_elem);
  200. }
  201. return (Proxy<T1>::use_at) ? accu_proxy_at(P) : accu_proxy_linear(P);
  202. }
  203. //! explicit handling of multiply-and-accumulate
  204. template<typename T1, typename T2>
  205. arma_warn_unused
  206. inline
  207. typename T1::elem_type
  208. accu(const eGlue<T1,T2,eglue_schur>& expr)
  209. {
  210. arma_extra_debug_sigprint();
  211. typedef eGlue<T1,T2,eglue_schur> expr_type;
  212. typedef typename expr_type::proxy1_type::stored_type P1_stored_type;
  213. typedef typename expr_type::proxy2_type::stored_type P2_stored_type;
  214. const bool have_direct_mem_1 = (is_Mat<P1_stored_type>::value) || (is_subview_col<P1_stored_type>::value);
  215. const bool have_direct_mem_2 = (is_Mat<P2_stored_type>::value) || (is_subview_col<P2_stored_type>::value);
  216. if(have_direct_mem_1 && have_direct_mem_2)
  217. {
  218. const quasi_unwrap<P1_stored_type> tmp1(expr.P1.Q);
  219. const quasi_unwrap<P2_stored_type> tmp2(expr.P2.Q);
  220. return op_dot::direct_dot(tmp1.M.n_elem, tmp1.M.memptr(), tmp2.M.memptr());
  221. }
  222. const Proxy<expr_type> P(expr);
  223. return (Proxy<expr_type>::use_at) ? accu_proxy_at(P) : accu_proxy_linear(P);
  224. }
  225. //! explicit handling of Hamming norm (also known as zero norm)
  226. template<typename T1>
  227. arma_warn_unused
  228. inline
  229. uword
  230. accu(const mtOp<uword,T1,op_rel_noteq>& X)
  231. {
  232. arma_extra_debug_sigprint();
  233. typedef typename T1::elem_type eT;
  234. const eT val = X.aux;
  235. const Proxy<T1> P(X.m);
  236. uword n_nonzero = 0;
  237. if(Proxy<T1>::use_at == false)
  238. {
  239. typedef typename Proxy<T1>::ea_type ea_type;
  240. ea_type A = P.get_ea();
  241. const uword n_elem = P.get_n_elem();
  242. for(uword i=0; i<n_elem; ++i)
  243. {
  244. n_nonzero += (A[i] != val) ? uword(1) : uword(0);
  245. }
  246. }
  247. else
  248. {
  249. const uword P_n_cols = P.get_n_cols();
  250. const uword P_n_rows = P.get_n_rows();
  251. if(P_n_rows == 1)
  252. {
  253. for(uword col=0; col < P_n_cols; ++col)
  254. {
  255. n_nonzero += (P.at(0,col) != val) ? uword(1) : uword(0);
  256. }
  257. }
  258. else
  259. {
  260. for(uword col=0; col < P_n_cols; ++col)
  261. for(uword row=0; row < P_n_rows; ++row)
  262. {
  263. n_nonzero += (P.at(row,col) != val) ? uword(1) : uword(0);
  264. }
  265. }
  266. }
  267. return n_nonzero;
  268. }
  269. template<typename T1>
  270. arma_warn_unused
  271. inline
  272. uword
  273. accu(const mtOp<uword,T1,op_rel_eq>& X)
  274. {
  275. arma_extra_debug_sigprint();
  276. typedef typename T1::elem_type eT;
  277. const eT val = X.aux;
  278. const Proxy<T1> P(X.m);
  279. uword n_nonzero = 0;
  280. if(Proxy<T1>::use_at == false)
  281. {
  282. typedef typename Proxy<T1>::ea_type ea_type;
  283. ea_type A = P.get_ea();
  284. const uword n_elem = P.get_n_elem();
  285. for(uword i=0; i<n_elem; ++i)
  286. {
  287. n_nonzero += (A[i] == val) ? uword(1) : uword(0);
  288. }
  289. }
  290. else
  291. {
  292. const uword P_n_cols = P.get_n_cols();
  293. const uword P_n_rows = P.get_n_rows();
  294. if(P_n_rows == 1)
  295. {
  296. for(uword col=0; col < P_n_cols; ++col)
  297. {
  298. n_nonzero += (P.at(0,col) == val) ? uword(1) : uword(0);
  299. }
  300. }
  301. else
  302. {
  303. for(uword col=0; col < P_n_cols; ++col)
  304. for(uword row=0; row < P_n_rows; ++row)
  305. {
  306. n_nonzero += (P.at(row,col) == val) ? uword(1) : uword(0);
  307. }
  308. }
  309. }
  310. return n_nonzero;
  311. }
  312. template<typename T1, typename T2>
  313. arma_warn_unused
  314. inline
  315. uword
  316. accu(const mtGlue<uword,T1,T2,glue_rel_noteq>& X)
  317. {
  318. arma_extra_debug_sigprint();
  319. const Proxy<T1> PA(X.A);
  320. const Proxy<T2> PB(X.B);
  321. arma_debug_assert_same_size(PA, PB, "operator!=");
  322. uword n_nonzero = 0;
  323. if( (Proxy<T1>::use_at == false) && (Proxy<T2>::use_at == false) )
  324. {
  325. typedef typename Proxy<T1>::ea_type PA_ea_type;
  326. typedef typename Proxy<T2>::ea_type PB_ea_type;
  327. PA_ea_type A = PA.get_ea();
  328. PB_ea_type B = PB.get_ea();
  329. const uword n_elem = PA.get_n_elem();
  330. for(uword i=0; i < n_elem; ++i)
  331. {
  332. n_nonzero += (A[i] != B[i]) ? uword(1) : uword(0);
  333. }
  334. }
  335. else
  336. {
  337. const uword PA_n_cols = PA.get_n_cols();
  338. const uword PA_n_rows = PA.get_n_rows();
  339. if(PA_n_rows == 1)
  340. {
  341. for(uword col=0; col < PA_n_cols; ++col)
  342. {
  343. n_nonzero += (PA.at(0,col) != PB.at(0,col)) ? uword(1) : uword(0);
  344. }
  345. }
  346. else
  347. {
  348. for(uword col=0; col < PA_n_cols; ++col)
  349. for(uword row=0; row < PA_n_rows; ++row)
  350. {
  351. n_nonzero += (PA.at(row,col) != PB.at(row,col)) ? uword(1) : uword(0);
  352. }
  353. }
  354. }
  355. return n_nonzero;
  356. }
  357. template<typename T1, typename T2>
  358. arma_warn_unused
  359. inline
  360. uword
  361. accu(const mtGlue<uword,T1,T2,glue_rel_eq>& X)
  362. {
  363. arma_extra_debug_sigprint();
  364. const Proxy<T1> PA(X.A);
  365. const Proxy<T2> PB(X.B);
  366. arma_debug_assert_same_size(PA, PB, "operator==");
  367. uword n_nonzero = 0;
  368. if( (Proxy<T1>::use_at == false) && (Proxy<T2>::use_at == false) )
  369. {
  370. typedef typename Proxy<T1>::ea_type PA_ea_type;
  371. typedef typename Proxy<T2>::ea_type PB_ea_type;
  372. PA_ea_type A = PA.get_ea();
  373. PB_ea_type B = PB.get_ea();
  374. const uword n_elem = PA.get_n_elem();
  375. for(uword i=0; i < n_elem; ++i)
  376. {
  377. n_nonzero += (A[i] == B[i]) ? uword(1) : uword(0);
  378. }
  379. }
  380. else
  381. {
  382. const uword PA_n_cols = PA.get_n_cols();
  383. const uword PA_n_rows = PA.get_n_rows();
  384. if(PA_n_rows == 1)
  385. {
  386. for(uword col=0; col < PA_n_cols; ++col)
  387. {
  388. n_nonzero += (PA.at(0,col) == PB.at(0,col)) ? uword(1) : uword(0);
  389. }
  390. }
  391. else
  392. {
  393. for(uword col=0; col < PA_n_cols; ++col)
  394. for(uword row=0; row < PA_n_rows; ++row)
  395. {
  396. n_nonzero += (PA.at(row,col) == PB.at(row,col)) ? uword(1) : uword(0);
  397. }
  398. }
  399. }
  400. return n_nonzero;
  401. }
  402. //! accumulate the elements of a subview (submatrix)
  403. template<typename eT>
  404. arma_warn_unused
  405. arma_hot
  406. inline
  407. eT
  408. accu(const subview<eT>& X)
  409. {
  410. arma_extra_debug_sigprint();
  411. const uword X_n_rows = X.n_rows;
  412. const uword X_n_cols = X.n_cols;
  413. eT val = eT(0);
  414. if(X_n_rows == 1)
  415. {
  416. typedef subview_row<eT> sv_type;
  417. const sv_type& sv = reinterpret_cast<const sv_type&>(X); // subview_row<eT> is a child class of subview<eT> and has no extra data
  418. const Proxy<sv_type> P(sv);
  419. val = accu_proxy_linear(P);
  420. }
  421. else
  422. if(X_n_cols == 1)
  423. {
  424. val = arrayops::accumulate( X.colptr(0), X_n_rows );
  425. }
  426. else
  427. {
  428. for(uword col=0; col < X_n_cols; ++col)
  429. {
  430. val += arrayops::accumulate( X.colptr(col), X_n_rows );
  431. }
  432. }
  433. return val;
  434. }
  435. template<typename eT>
  436. arma_warn_unused
  437. arma_hot
  438. inline
  439. eT
  440. accu(const subview_col<eT>& X)
  441. {
  442. arma_extra_debug_sigprint();
  443. return arrayops::accumulate( X.colptr(0), X.n_rows );
  444. }
  445. //
  446. template<typename T1>
  447. arma_hot
  448. inline
  449. typename T1::elem_type
  450. accu_cube_proxy_linear(const ProxyCube<T1>& P)
  451. {
  452. arma_extra_debug_sigprint();
  453. typedef typename T1::elem_type eT;
  454. eT val = eT(0);
  455. typename ProxyCube<T1>::ea_type Pea = P.get_ea();
  456. const uword n_elem = P.get_n_elem();
  457. if( arma_config::openmp && ProxyCube<T1>::use_mp && mp_gate<eT>::eval(n_elem) )
  458. {
  459. #if defined(ARMA_USE_OPENMP)
  460. {
  461. // NOTE: using parallelisation with manual reduction workaround to take into account complex numbers;
  462. // NOTE: OpenMP versions lower than 4.0 do not support user-defined reduction
  463. const int n_threads_max = mp_thread_limit::get();
  464. const uword n_threads_use = (std::min)(uword(podarray_prealloc_n_elem::val), uword(n_threads_max));
  465. const uword chunk_size = n_elem / n_threads_use;
  466. podarray<eT> partial_accs(n_threads_use);
  467. #pragma omp parallel for schedule(static) num_threads(int(n_threads_use))
  468. for(uword thread_id=0; thread_id < n_threads_use; ++thread_id)
  469. {
  470. const uword start = (thread_id+0) * chunk_size;
  471. const uword endp1 = (thread_id+1) * chunk_size;
  472. eT acc = eT(0);
  473. for(uword i=start; i < endp1; ++i) { acc += Pea[i]; }
  474. partial_accs[thread_id] = acc;
  475. }
  476. for(uword thread_id=0; thread_id < n_threads_use; ++thread_id) { val += partial_accs[thread_id]; }
  477. for(uword i=(n_threads_use*chunk_size); i < n_elem; ++i) { val += Pea[i]; }
  478. }
  479. #endif
  480. }
  481. else
  482. {
  483. #if defined(__FINITE_MATH_ONLY__) && (__FINITE_MATH_ONLY__ > 0)
  484. {
  485. if(P.is_aligned())
  486. {
  487. typename ProxyCube<T1>::aligned_ea_type Pea_aligned = P.get_aligned_ea();
  488. for(uword i=0; i<n_elem; ++i) { val += Pea_aligned.at_alt(i); }
  489. }
  490. else
  491. {
  492. for(uword i=0; i<n_elem; ++i) { val += Pea[i]; }
  493. }
  494. }
  495. #else
  496. {
  497. eT val1 = eT(0);
  498. eT val2 = eT(0);
  499. uword i,j;
  500. for(i=0, j=1; j<n_elem; i+=2, j+=2) { val1 += Pea[i]; val2 += Pea[j]; }
  501. if(i < n_elem) { val1 += Pea[i]; }
  502. val = val1 + val2;
  503. }
  504. #endif
  505. }
  506. return val;
  507. }
  508. template<typename T1>
  509. arma_hot
  510. inline
  511. typename T1::elem_type
  512. accu_cube_proxy_at_mp(const ProxyCube<T1>& P)
  513. {
  514. arma_extra_debug_sigprint();
  515. typedef typename T1::elem_type eT;
  516. eT val = eT(0);
  517. #if defined(ARMA_USE_OPENMP)
  518. {
  519. const uword n_rows = P.get_n_rows();
  520. const uword n_cols = P.get_n_cols();
  521. const uword n_slices = P.get_n_slices();
  522. podarray<eT> slice_accs(n_slices);
  523. const int n_threads = mp_thread_limit::get();
  524. #pragma omp parallel for schedule(static) num_threads(n_threads)
  525. for(uword slice = 0; slice < n_slices; ++slice)
  526. {
  527. eT val1 = eT(0);
  528. eT val2 = eT(0);
  529. for(uword col = 0; col < n_cols; ++col)
  530. {
  531. uword i,j;
  532. for(i=0, j=1; j<n_rows; i+=2, j+=2) { val1 += P.at(i,col,slice); val2 += P.at(j,col,slice); }
  533. if(i < n_rows) { val1 += P.at(i,col,slice); }
  534. }
  535. slice_accs[slice] = val1 + val2;
  536. }
  537. val = arrayops::accumulate(slice_accs.memptr(), slice_accs.n_elem);
  538. }
  539. #else
  540. {
  541. arma_ignore(P);
  542. }
  543. #endif
  544. return val;
  545. }
  546. template<typename T1>
  547. arma_hot
  548. inline
  549. typename T1::elem_type
  550. accu_cube_proxy_at(const ProxyCube<T1>& P)
  551. {
  552. arma_extra_debug_sigprint();
  553. typedef typename T1::elem_type eT;
  554. if(arma_config::openmp && ProxyCube<T1>::use_mp && mp_gate<eT>::eval(P.get_n_elem()))
  555. {
  556. return accu_cube_proxy_at_mp(P);
  557. }
  558. const uword n_rows = P.get_n_rows();
  559. const uword n_cols = P.get_n_cols();
  560. const uword n_slices = P.get_n_slices();
  561. eT val1 = eT(0);
  562. eT val2 = eT(0);
  563. for(uword slice = 0; slice < n_slices; ++slice)
  564. for(uword col = 0; col < n_cols; ++col )
  565. {
  566. uword i,j;
  567. for(i=0, j=1; j<n_rows; i+=2, j+=2) { val1 += P.at(i,col,slice); val2 += P.at(j,col,slice); }
  568. if(i < n_rows) { val1 += P.at(i,col,slice); }
  569. }
  570. return (val1 + val2);
  571. }
  572. //! accumulate the elements of a cube
  573. template<typename T1>
  574. arma_warn_unused
  575. arma_hot
  576. inline
  577. typename T1::elem_type
  578. accu(const BaseCube<typename T1::elem_type,T1>& X)
  579. {
  580. arma_extra_debug_sigprint();
  581. const ProxyCube<T1> P(X.get_ref());
  582. if(is_Cube<typename ProxyCube<T1>::stored_type>::value)
  583. {
  584. unwrap_cube<typename ProxyCube<T1>::stored_type> tmp(P.Q);
  585. return arrayops::accumulate(tmp.M.memptr(), tmp.M.n_elem);
  586. }
  587. return (ProxyCube<T1>::use_at) ? accu_cube_proxy_at(P) : accu_cube_proxy_linear(P);
  588. }
  589. //! explicit handling of multiply-and-accumulate (cube version)
  590. template<typename T1, typename T2>
  591. arma_warn_unused
  592. inline
  593. typename T1::elem_type
  594. accu(const eGlueCube<T1,T2,eglue_schur>& expr)
  595. {
  596. arma_extra_debug_sigprint();
  597. typedef eGlueCube<T1,T2,eglue_schur> expr_type;
  598. typedef typename ProxyCube<T1>::stored_type P1_stored_type;
  599. typedef typename ProxyCube<T2>::stored_type P2_stored_type;
  600. if(is_Cube<P1_stored_type>::value && is_Cube<P2_stored_type>::value)
  601. {
  602. const unwrap_cube<P1_stored_type> tmp1(expr.P1.Q);
  603. const unwrap_cube<P2_stored_type> tmp2(expr.P2.Q);
  604. return op_dot::direct_dot(tmp1.M.n_elem, tmp1.M.memptr(), tmp2.M.memptr());
  605. }
  606. const ProxyCube<expr_type> P(expr);
  607. return (ProxyCube<expr_type>::use_at) ? accu_cube_proxy_at(P) : accu_cube_proxy_linear(P);
  608. }
  609. //
  610. template<typename T>
  611. arma_warn_unused
  612. inline
  613. typename arma_scalar_only<T>::result
  614. accu(const T& x)
  615. {
  616. return x;
  617. }
  618. //! accumulate values in a sparse object
  619. template<typename T1>
  620. arma_warn_unused
  621. inline
  622. typename T1::elem_type
  623. accu(const SpBase<typename T1::elem_type,T1>& expr)
  624. {
  625. arma_extra_debug_sigprint();
  626. typedef typename T1::elem_type eT;
  627. const SpProxy<T1> P(expr.get_ref());
  628. if(SpProxy<T1>::use_iterator == false)
  629. {
  630. // direct counting
  631. return arrayops::accumulate(P.get_values(), P.get_n_nonzero());
  632. }
  633. else
  634. {
  635. typename SpProxy<T1>::const_iterator_type it = P.begin();
  636. const uword P_n_nz = P.get_n_nonzero();
  637. eT val = eT(0);
  638. for(uword i=0; i < P_n_nz; ++i) { val += (*it); ++it; }
  639. return val;
  640. }
  641. }
  642. //! explicit handling of accu(A + B), where A and B are sparse matrices
  643. template<typename T1, typename T2>
  644. arma_warn_unused
  645. inline
  646. typename T1::elem_type
  647. accu(const SpGlue<T1,T2,spglue_plus>& expr)
  648. {
  649. arma_extra_debug_sigprint();
  650. const unwrap_spmat<T1> UA(expr.A);
  651. const unwrap_spmat<T2> UB(expr.B);
  652. arma_debug_assert_same_size(UA.M.n_rows, UA.M.n_cols, UB.M.n_rows, UB.M.n_cols, "addition");
  653. return (accu(UA.M) + accu(UB.M));
  654. }
  655. //! explicit handling of accu(A - B), where A and B are sparse matrices
  656. template<typename T1, typename T2>
  657. arma_warn_unused
  658. inline
  659. typename T1::elem_type
  660. accu(const SpGlue<T1,T2,spglue_minus>& expr)
  661. {
  662. arma_extra_debug_sigprint();
  663. const unwrap_spmat<T1> UA(expr.A);
  664. const unwrap_spmat<T2> UB(expr.B);
  665. arma_debug_assert_same_size(UA.M.n_rows, UA.M.n_cols, UB.M.n_rows, UB.M.n_cols, "subtraction");
  666. return (accu(UA.M) - accu(UB.M));
  667. }
  668. //! explicit handling of accu(A % B), where A and B are sparse matrices
  669. template<typename T1, typename T2>
  670. arma_warn_unused
  671. inline
  672. typename T1::elem_type
  673. accu(const SpGlue<T1,T2,spglue_schur>& expr)
  674. {
  675. arma_extra_debug_sigprint();
  676. typedef typename T1::elem_type eT;
  677. const SpProxy<T1> px(expr.A);
  678. const SpProxy<T2> py(expr.B);
  679. typename SpProxy<T1>::const_iterator_type x_it = px.begin();
  680. typename SpProxy<T1>::const_iterator_type x_it_end = px.end();
  681. typename SpProxy<T2>::const_iterator_type y_it = py.begin();
  682. typename SpProxy<T2>::const_iterator_type y_it_end = py.end();
  683. eT acc = eT(0);
  684. while( (x_it != x_it_end) || (y_it != y_it_end) )
  685. {
  686. if(x_it == y_it)
  687. {
  688. acc += ((*x_it) * (*y_it));
  689. ++x_it;
  690. ++y_it;
  691. }
  692. else
  693. {
  694. const uword x_it_col = x_it.col();
  695. const uword x_it_row = x_it.row();
  696. const uword y_it_col = y_it.col();
  697. const uword y_it_row = y_it.row();
  698. if((x_it_col < y_it_col) || ((x_it_col == y_it_col) && (x_it_row < y_it_row))) // if y is closer to the end
  699. {
  700. ++x_it;
  701. }
  702. else // x is closer to the end
  703. {
  704. ++y_it;
  705. }
  706. }
  707. }
  708. return acc;
  709. }
  710. template<typename T1, typename spop_type>
  711. arma_warn_unused
  712. inline
  713. typename T1::elem_type
  714. accu(const SpOp<T1, spop_type>& expr)
  715. {
  716. arma_extra_debug_sigprint();
  717. typedef typename T1::elem_type eT;
  718. const bool is_vectorise = \
  719. (is_same_type<spop_type, spop_vectorise_row>::yes)
  720. || (is_same_type<spop_type, spop_vectorise_col>::yes)
  721. || (is_same_type<spop_type, spop_vectorise_all>::yes);
  722. if(is_vectorise)
  723. {
  724. return accu(expr.m);
  725. }
  726. const SpMat<eT> tmp = expr;
  727. return accu(tmp);
  728. }
  729. //! @}