fn_trace.hpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup fn_trace
  16. //! @{
  17. template<typename T1>
  18. arma_warn_unused
  19. inline
  20. typename T1::elem_type
  21. trace(const Base<typename T1::elem_type, T1>& X)
  22. {
  23. arma_extra_debug_sigprint();
  24. typedef typename T1::elem_type eT;
  25. const Proxy<T1> P(X.get_ref());
  26. const uword N = (std::min)(P.get_n_rows(), P.get_n_cols());
  27. eT val1 = eT(0);
  28. eT val2 = eT(0);
  29. uword i,j;
  30. for(i=0, j=1; j<N; i+=2, j+=2)
  31. {
  32. val1 += P.at(i,i);
  33. val2 += P.at(j,j);
  34. }
  35. if(i < N)
  36. {
  37. val1 += P.at(i,i);
  38. }
  39. return val1 + val2;
  40. }
  41. template<typename T1>
  42. arma_warn_unused
  43. inline
  44. typename T1::elem_type
  45. trace(const Op<T1, op_diagmat>& X)
  46. {
  47. arma_extra_debug_sigprint();
  48. typedef typename T1::elem_type eT;
  49. const diagmat_proxy<T1> A(X.m);
  50. const uword N = (std::min)(A.n_rows, A.n_cols);
  51. eT val = eT(0);
  52. for(uword i=0; i<N; ++i)
  53. {
  54. val += A[i];
  55. }
  56. return val;
  57. }
  58. //! speedup for trace(A*B); non-complex elements
  59. template<typename T1, typename T2>
  60. arma_warn_unused
  61. inline
  62. typename enable_if2< is_cx<typename T1::elem_type>::no, typename T1::elem_type>::result
  63. trace(const Glue<T1, T2, glue_times>& X)
  64. {
  65. arma_extra_debug_sigprint();
  66. typedef typename T1::elem_type eT;
  67. const partial_unwrap<T1> tmp1(X.A);
  68. const partial_unwrap<T2> tmp2(X.B);
  69. const typename partial_unwrap<T1>::stored_type& A = tmp1.M;
  70. const typename partial_unwrap<T2>::stored_type& B = tmp2.M;
  71. const bool use_alpha = partial_unwrap<T1>::do_times || partial_unwrap<T2>::do_times;
  72. const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val()) : eT(0);
  73. arma_debug_assert_trans_mul_size< partial_unwrap<T1>::do_trans, partial_unwrap<T2>::do_trans >(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication");
  74. if( (A.n_elem == 0) || (B.n_elem == 0) )
  75. {
  76. return eT(0);
  77. }
  78. const uword A_n_rows = A.n_rows;
  79. const uword A_n_cols = A.n_cols;
  80. const uword B_n_rows = B.n_rows;
  81. const uword B_n_cols = B.n_cols;
  82. eT acc = eT(0);
  83. if( (partial_unwrap<T1>::do_trans == false) && (partial_unwrap<T2>::do_trans == false) )
  84. {
  85. const uword N = (std::min)(A_n_rows, B_n_cols);
  86. eT acc1 = eT(0);
  87. eT acc2 = eT(0);
  88. for(uword k=0; k < N; ++k)
  89. {
  90. const eT* B_colptr = B.colptr(k);
  91. // condition: A_n_cols = B_n_rows
  92. uword j;
  93. for(j=1; j < A_n_cols; j+=2)
  94. {
  95. const uword i = (j-1);
  96. const eT tmp_i = B_colptr[i];
  97. const eT tmp_j = B_colptr[j];
  98. acc1 += A.at(k, i) * tmp_i;
  99. acc2 += A.at(k, j) * tmp_j;
  100. }
  101. const uword i = (j-1);
  102. if(i < A_n_cols)
  103. {
  104. acc1 += A.at(k, i) * B_colptr[i];
  105. }
  106. }
  107. acc = (acc1 + acc2);
  108. }
  109. else
  110. if( (partial_unwrap<T1>::do_trans == true ) && (partial_unwrap<T2>::do_trans == false) )
  111. {
  112. const uword N = (std::min)(A_n_cols, B_n_cols);
  113. for(uword k=0; k < N; ++k)
  114. {
  115. const eT* A_colptr = A.colptr(k);
  116. const eT* B_colptr = B.colptr(k);
  117. // condition: A_n_rows = B_n_rows
  118. acc += op_dot::direct_dot(A_n_rows, A_colptr, B_colptr);
  119. }
  120. }
  121. else
  122. if( (partial_unwrap<T1>::do_trans == false) && (partial_unwrap<T2>::do_trans == true ) )
  123. {
  124. const uword N = (std::min)(A_n_rows, B_n_rows);
  125. for(uword k=0; k < N; ++k)
  126. {
  127. // condition: A_n_cols = B_n_cols
  128. for(uword i=0; i < A_n_cols; ++i)
  129. {
  130. acc += A.at(k,i) * B.at(k,i);
  131. }
  132. }
  133. }
  134. else
  135. if( (partial_unwrap<T1>::do_trans == true ) && (partial_unwrap<T2>::do_trans == true ) )
  136. {
  137. const uword N = (std::min)(A_n_cols, B_n_rows);
  138. for(uword k=0; k < N; ++k)
  139. {
  140. const eT* A_colptr = A.colptr(k);
  141. // condition: A_n_rows = B_n_cols
  142. for(uword i=0; i < A_n_rows; ++i)
  143. {
  144. acc += A_colptr[i] * B.at(k,i);
  145. }
  146. }
  147. }
  148. return (use_alpha) ? (alpha * acc) : acc;
  149. }
  150. //! speedup for trace(A*B); complex elements
  151. template<typename T1, typename T2>
  152. arma_warn_unused
  153. inline
  154. typename enable_if2< is_cx<typename T1::elem_type>::yes, typename T1::elem_type>::result
  155. trace(const Glue<T1, T2, glue_times>& X)
  156. {
  157. arma_extra_debug_sigprint();
  158. typedef typename T1::pod_type T;
  159. typedef typename T1::elem_type eT;
  160. const partial_unwrap<T1> tmp1(X.A);
  161. const partial_unwrap<T2> tmp2(X.B);
  162. const typename partial_unwrap<T1>::stored_type& A = tmp1.M;
  163. const typename partial_unwrap<T2>::stored_type& B = tmp2.M;
  164. const bool use_alpha = partial_unwrap<T1>::do_times || partial_unwrap<T2>::do_times;
  165. const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val()) : eT(0);
  166. arma_debug_assert_trans_mul_size< partial_unwrap<T1>::do_trans, partial_unwrap<T2>::do_trans >(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication");
  167. if( (A.n_elem == 0) || (B.n_elem == 0) )
  168. {
  169. return eT(0);
  170. }
  171. const uword A_n_rows = A.n_rows;
  172. const uword A_n_cols = A.n_cols;
  173. const uword B_n_rows = B.n_rows;
  174. const uword B_n_cols = B.n_cols;
  175. eT acc = eT(0);
  176. if( (partial_unwrap<T1>::do_trans == false) && (partial_unwrap<T2>::do_trans == false) )
  177. {
  178. const uword N = (std::min)(A_n_rows, B_n_cols);
  179. T acc_real = T(0);
  180. T acc_imag = T(0);
  181. for(uword k=0; k < N; ++k)
  182. {
  183. const eT* B_colptr = B.colptr(k);
  184. // condition: A_n_cols = B_n_rows
  185. for(uword i=0; i < A_n_cols; ++i)
  186. {
  187. // acc += A.at(k, i) * B_colptr[i];
  188. const std::complex<T>& xx = A.at(k, i);
  189. const std::complex<T>& yy = B_colptr[i];
  190. const T a = xx.real();
  191. const T b = xx.imag();
  192. const T c = yy.real();
  193. const T d = yy.imag();
  194. acc_real += (a*c) - (b*d);
  195. acc_imag += (a*d) + (b*c);
  196. }
  197. }
  198. acc = std::complex<T>(acc_real, acc_imag);
  199. }
  200. else
  201. if( (partial_unwrap<T1>::do_trans == true) && (partial_unwrap<T2>::do_trans == false) )
  202. {
  203. const uword N = (std::min)(A_n_cols, B_n_cols);
  204. T acc_real = T(0);
  205. T acc_imag = T(0);
  206. for(uword k=0; k < N; ++k)
  207. {
  208. const eT* A_colptr = A.colptr(k);
  209. const eT* B_colptr = B.colptr(k);
  210. // condition: A_n_rows = B_n_rows
  211. for(uword i=0; i < A_n_rows; ++i)
  212. {
  213. // acc += std::conj(A_colptr[i]) * B_colptr[i];
  214. const std::complex<T>& xx = A_colptr[i];
  215. const std::complex<T>& yy = B_colptr[i];
  216. const T a = xx.real();
  217. const T b = xx.imag();
  218. const T c = yy.real();
  219. const T d = yy.imag();
  220. // take into account the complex conjugate of xx
  221. acc_real += (a*c) + (b*d);
  222. acc_imag += (a*d) - (b*c);
  223. }
  224. }
  225. acc = std::complex<T>(acc_real, acc_imag);
  226. }
  227. else
  228. if( (partial_unwrap<T1>::do_trans == false) && (partial_unwrap<T2>::do_trans == true) )
  229. {
  230. const uword N = (std::min)(A_n_rows, B_n_rows);
  231. T acc_real = T(0);
  232. T acc_imag = T(0);
  233. for(uword k=0; k < N; ++k)
  234. {
  235. // condition: A_n_cols = B_n_cols
  236. for(uword i=0; i < A_n_cols; ++i)
  237. {
  238. // acc += A.at(k,i) * std::conj(B.at(k,i));
  239. const std::complex<T>& xx = A.at(k, i);
  240. const std::complex<T>& yy = B.at(k, i);
  241. const T a = xx.real();
  242. const T b = xx.imag();
  243. const T c = yy.real();
  244. const T d = -yy.imag(); // take the conjugate
  245. acc_real += (a*c) - (b*d);
  246. acc_imag += (a*d) + (b*c);
  247. }
  248. }
  249. acc = std::complex<T>(acc_real, acc_imag);
  250. }
  251. else
  252. if( (partial_unwrap<T1>::do_trans == true) && (partial_unwrap<T2>::do_trans == true) )
  253. {
  254. const uword N = (std::min)(A_n_cols, B_n_rows);
  255. T acc_real = T(0);
  256. T acc_imag = T(0);
  257. for(uword k=0; k < N; ++k)
  258. {
  259. const eT* A_colptr = A.colptr(k);
  260. // condition: A_n_rows = B_n_cols
  261. for(uword i=0; i < A_n_rows; ++i)
  262. {
  263. // acc += std::conj(A_colptr[i]) * std::conj(B.at(k,i));
  264. const std::complex<T>& xx = A_colptr[i];
  265. const std::complex<T>& yy = B.at(k, i);
  266. const T a = xx.real();
  267. const T b = -xx.imag(); // take the conjugate
  268. const T c = yy.real();
  269. const T d = -yy.imag(); // take the conjugate
  270. acc_real += (a*c) - (b*d);
  271. acc_imag += (a*d) + (b*c);
  272. }
  273. }
  274. acc = std::complex<T>(acc_real, acc_imag);
  275. }
  276. return (use_alpha) ? eT(alpha * acc) : eT(acc);
  277. }
  278. //! trace of sparse object; generic version
  279. template<typename T1>
  280. arma_warn_unused
  281. inline
  282. typename T1::elem_type
  283. trace(const SpBase<typename T1::elem_type,T1>& expr)
  284. {
  285. arma_extra_debug_sigprint();
  286. typedef typename T1::elem_type eT;
  287. const SpProxy<T1> P(expr.get_ref());
  288. const uword N = (std::min)(P.get_n_rows(), P.get_n_cols());
  289. eT acc = eT(0);
  290. if( (is_SpMat<typename SpProxy<T1>::stored_type>::value) && (P.get_n_nonzero() >= 5*N) )
  291. {
  292. const unwrap_spmat<typename SpProxy<T1>::stored_type> U(P.Q);
  293. const SpMat<eT>& X = U.M;
  294. for(uword i=0; i < N; ++i)
  295. {
  296. acc += X.at(i,i); // use binary search
  297. }
  298. }
  299. else
  300. {
  301. typename SpProxy<T1>::const_iterator_type it = P.begin();
  302. const uword P_n_nz = P.get_n_nonzero();
  303. for(uword i=0; i < P_n_nz; ++i)
  304. {
  305. if(it.row() == it.col()) { acc += (*it); }
  306. ++it;
  307. }
  308. }
  309. return acc;
  310. }
  311. //! trace of sparse object; speedup for trace(A + B)
  312. template<typename T1, typename T2>
  313. arma_warn_unused
  314. inline
  315. typename T1::elem_type
  316. trace(const SpGlue<T1, T2, spglue_plus>& expr)
  317. {
  318. arma_extra_debug_sigprint();
  319. const unwrap_spmat<T1> UA(expr.A);
  320. const unwrap_spmat<T2> UB(expr.B);
  321. arma_debug_assert_same_size(UA.M.n_rows, UA.M.n_cols, UB.M.n_rows, UB.M.n_cols, "addition");
  322. return (trace(UA.M) + trace(UB.M));
  323. }
  324. //! trace of sparse object; speedup for trace(A - B)
  325. template<typename T1, typename T2>
  326. arma_warn_unused
  327. inline
  328. typename T1::elem_type
  329. trace(const SpGlue<T1, T2, spglue_minus>& expr)
  330. {
  331. arma_extra_debug_sigprint();
  332. const unwrap_spmat<T1> UA(expr.A);
  333. const unwrap_spmat<T2> UB(expr.B);
  334. arma_debug_assert_same_size(UA.M.n_rows, UA.M.n_cols, UB.M.n_rows, UB.M.n_cols, "subtraction");
  335. return (trace(UA.M) - trace(UB.M));
  336. }
  337. //! trace of sparse object; speedup for trace(A % B)
  338. template<typename T1, typename T2>
  339. arma_warn_unused
  340. inline
  341. typename T1::elem_type
  342. trace(const SpGlue<T1, T2, spglue_schur>& expr)
  343. {
  344. arma_extra_debug_sigprint();
  345. typedef typename T1::elem_type eT;
  346. const unwrap_spmat<T1> UA(expr.A);
  347. const unwrap_spmat<T2> UB(expr.B);
  348. const SpMat<eT>& A = UA.M;
  349. const SpMat<eT>& B = UB.M;
  350. arma_debug_assert_same_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "element-wise multiplication");
  351. const uword N = (std::min)(A.n_rows, A.n_cols);
  352. eT acc = eT(0);
  353. for(uword i=0; i<N; ++i)
  354. {
  355. acc += A.at(i,i) * B.at(i,i);
  356. }
  357. return acc;
  358. }
  359. //! trace of sparse object; speedup for trace(A*B)
  360. template<typename T1, typename T2>
  361. arma_warn_unused
  362. inline
  363. typename T1::elem_type
  364. trace(const SpGlue<T1, T2, spglue_times>& expr)
  365. {
  366. arma_extra_debug_sigprint();
  367. typedef typename T1::elem_type eT;
  368. // better-than-nothing implementation
  369. const unwrap_spmat<T1> UA(expr.A);
  370. const unwrap_spmat<T2> UB(expr.B);
  371. const SpMat<eT>& A = UA.M;
  372. const SpMat<eT>& B = UB.M;
  373. arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication");
  374. if( (A.n_nonzero == 0) || (B.n_nonzero == 0) )
  375. {
  376. return eT(0);
  377. }
  378. const uword N = (std::min)(A.n_rows, B.n_cols);
  379. eT acc = eT(0);
  380. // TODO: the threshold may need tuning for complex matrices
  381. if( (A.n_nonzero >= 5*N) || (B.n_nonzero >= 5*N) )
  382. {
  383. for(uword k=0; k < N; ++k)
  384. {
  385. typename SpMat<eT>::const_col_iterator B_it = B.begin_col_no_sync(k);
  386. typename SpMat<eT>::const_col_iterator B_it_end = B.end_col_no_sync(k);
  387. while(B_it != B_it_end)
  388. {
  389. const eT B_val = (*B_it);
  390. const uword i = B_it.row();
  391. acc += A.at(k,i) * B_val;
  392. ++B_it;
  393. }
  394. }
  395. }
  396. else
  397. {
  398. const SpMat<eT> AB = A * B;
  399. acc = trace(AB);
  400. }
  401. return acc;
  402. }
  403. //! trace of sparse object; speedup for trace(A.t()*B); non-complex elements
  404. template<typename T1, typename T2>
  405. arma_warn_unused
  406. inline
  407. typename enable_if2< is_cx<typename T1::elem_type>::no, typename T1::elem_type>::result
  408. trace(const SpGlue<SpOp<T1, spop_htrans>, T2, spglue_times>& expr)
  409. {
  410. arma_extra_debug_sigprint();
  411. typedef typename T1::elem_type eT;
  412. const unwrap_spmat<T1> UA(expr.A.m);
  413. const unwrap_spmat<T2> UB(expr.B);
  414. const SpMat<eT>& A = UA.M;
  415. const SpMat<eT>& B = UB.M;
  416. // NOTE: deliberately swapped A.n_rows and A.n_cols to take into account the requested transpose operation
  417. arma_debug_assert_mul_size(A.n_cols, A.n_rows, B.n_rows, B.n_cols, "matrix multiplication");
  418. if( (A.n_nonzero == 0) || (B.n_nonzero == 0) )
  419. {
  420. return eT(0);
  421. }
  422. const uword N = (std::min)(A.n_cols, B.n_cols);
  423. eT acc = eT(0);
  424. if( (A.n_nonzero >= 5*N) || (B.n_nonzero >= 5*N) )
  425. {
  426. for(uword k=0; k < N; ++k)
  427. {
  428. typename SpMat<eT>::const_col_iterator B_it = B.begin_col_no_sync(k);
  429. typename SpMat<eT>::const_col_iterator B_it_end = B.end_col_no_sync(k);
  430. while(B_it != B_it_end)
  431. {
  432. const eT B_val = (*B_it);
  433. const uword i = B_it.row();
  434. acc += A.at(i,k) * B_val;
  435. ++B_it;
  436. }
  437. }
  438. }
  439. else
  440. {
  441. const SpMat<eT> AtB = A.t() * B;
  442. acc = trace(AtB);
  443. }
  444. return acc;
  445. }
  446. //! trace of sparse object; speedup for trace(A.t()*B); complex elements
  447. template<typename T1, typename T2>
  448. arma_warn_unused
  449. inline
  450. typename enable_if2< is_cx<typename T1::elem_type>::yes, typename T1::elem_type>::result
  451. trace(const SpGlue<SpOp<T1, spop_htrans>, T2, spglue_times>& expr)
  452. {
  453. arma_extra_debug_sigprint();
  454. typedef typename T1::elem_type eT;
  455. const unwrap_spmat<T1> UA(expr.A.m);
  456. const unwrap_spmat<T2> UB(expr.B);
  457. const SpMat<eT>& A = UA.M;
  458. const SpMat<eT>& B = UB.M;
  459. // NOTE: deliberately swapped A.n_rows and A.n_cols to take into account the requested transpose operation
  460. arma_debug_assert_mul_size(A.n_cols, A.n_rows, B.n_rows, B.n_cols, "matrix multiplication");
  461. if( (A.n_nonzero == 0) || (B.n_nonzero == 0) )
  462. {
  463. return eT(0);
  464. }
  465. const uword N = (std::min)(A.n_cols, B.n_cols);
  466. eT acc = eT(0);
  467. // TODO: the threshold may need tuning for complex matrices
  468. if( (A.n_nonzero >= 5*N) || (B.n_nonzero >= 5*N) )
  469. {
  470. for(uword k=0; k < N; ++k)
  471. {
  472. typename SpMat<eT>::const_col_iterator B_it = B.begin_col_no_sync(k);
  473. typename SpMat<eT>::const_col_iterator B_it_end = B.end_col_no_sync(k);
  474. while(B_it != B_it_end)
  475. {
  476. const eT B_val = (*B_it);
  477. const uword i = B_it.row();
  478. acc += std::conj(A.at(i,k)) * B_val;
  479. ++B_it;
  480. }
  481. }
  482. }
  483. else
  484. {
  485. const SpMat<eT> AtB = A.t() * B;
  486. acc = trace(AtB);
  487. }
  488. return acc;
  489. }
  490. //! @}