mul_syrk.hpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup syrk
  16. //! @{
  17. class syrk_helper
  18. {
  19. public:
  20. template<typename eT>
  21. inline
  22. static
  23. void
  24. inplace_copy_upper_tri_to_lower_tri(Mat<eT>& C)
  25. {
  26. // under the assumption that C is a square matrix
  27. const uword N = C.n_rows;
  28. for(uword k=0; k < N; ++k)
  29. {
  30. eT* colmem = C.colptr(k);
  31. uword i, j;
  32. for(i=(k+1), j=(k+2); j < N; i+=2, j+=2)
  33. {
  34. const eT tmp_i = C.at(k,i);
  35. const eT tmp_j = C.at(k,j);
  36. colmem[i] = tmp_i;
  37. colmem[j] = tmp_j;
  38. }
  39. if(i < N)
  40. {
  41. colmem[i] = C.at(k,i);
  42. }
  43. }
  44. }
  45. };
  46. //! partial emulation of BLAS function syrk(), specialised for A being a vector
  47. template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
  48. class syrk_vec
  49. {
  50. public:
  51. template<typename eT, typename TA>
  52. arma_hot
  53. inline
  54. static
  55. void
  56. apply
  57. (
  58. Mat<eT>& C,
  59. const TA& A,
  60. const eT alpha = eT(1),
  61. const eT beta = eT(0)
  62. )
  63. {
  64. arma_extra_debug_sigprint();
  65. const uword A_n1 = (do_trans_A == false) ? A.n_rows : A.n_cols;
  66. const uword A_n2 = (do_trans_A == false) ? A.n_cols : A.n_rows;
  67. const eT* A_mem = A.memptr();
  68. if(A_n1 == 1)
  69. {
  70. const eT acc1 = op_dot::direct_dot(A_n2, A_mem, A_mem);
  71. if( (use_alpha == false) && (use_beta == false) ) { C[0] = acc1; }
  72. else if( (use_alpha == true ) && (use_beta == false) ) { C[0] = alpha*acc1; }
  73. else if( (use_alpha == false) && (use_beta == true ) ) { C[0] = acc1 + beta*C[0]; }
  74. else if( (use_alpha == true ) && (use_beta == true ) ) { C[0] = alpha*acc1 + beta*C[0]; }
  75. }
  76. else
  77. for(uword k=0; k < A_n1; ++k)
  78. {
  79. const eT A_k = A_mem[k];
  80. uword i,j;
  81. for(i=(k), j=(k+1); j < A_n1; i+=2, j+=2)
  82. {
  83. const eT acc1 = A_k * A_mem[i];
  84. const eT acc2 = A_k * A_mem[j];
  85. if( (use_alpha == false) && (use_beta == false) )
  86. {
  87. C.at(k, i) = acc1;
  88. C.at(k, j) = acc2;
  89. C.at(i, k) = acc1;
  90. C.at(j, k) = acc2;
  91. }
  92. else
  93. if( (use_alpha == true ) && (use_beta == false) )
  94. {
  95. const eT val1 = alpha*acc1;
  96. const eT val2 = alpha*acc2;
  97. C.at(k, i) = val1;
  98. C.at(k, j) = val2;
  99. C.at(i, k) = val1;
  100. C.at(j, k) = val2;
  101. }
  102. else
  103. if( (use_alpha == false) && (use_beta == true) )
  104. {
  105. C.at(k, i) = acc1 + beta*C.at(k, i);
  106. C.at(k, j) = acc2 + beta*C.at(k, j);
  107. if(i != k) { C.at(i, k) = acc1 + beta*C.at(i, k); }
  108. C.at(j, k) = acc2 + beta*C.at(j, k);
  109. }
  110. else
  111. if( (use_alpha == true ) && (use_beta == true) )
  112. {
  113. const eT val1 = alpha*acc1;
  114. const eT val2 = alpha*acc2;
  115. C.at(k, i) = val1 + beta*C.at(k, i);
  116. C.at(k, j) = val2 + beta*C.at(k, j);
  117. if(i != k) { C.at(i, k) = val1 + beta*C.at(i, k); }
  118. C.at(j, k) = val2 + beta*C.at(j, k);
  119. }
  120. }
  121. if(i < A_n1)
  122. {
  123. const eT acc1 = A_k * A_mem[i];
  124. if( (use_alpha == false) && (use_beta == false) )
  125. {
  126. C.at(k, i) = acc1;
  127. C.at(i, k) = acc1;
  128. }
  129. else
  130. if( (use_alpha == true) && (use_beta == false) )
  131. {
  132. const eT val1 = alpha*acc1;
  133. C.at(k, i) = val1;
  134. C.at(i, k) = val1;
  135. }
  136. else
  137. if( (use_alpha == false) && (use_beta == true) )
  138. {
  139. C.at(k, i) = acc1 + beta*C.at(k, i);
  140. if(i != k) { C.at(i, k) = acc1 + beta*C.at(i, k); }
  141. }
  142. else
  143. if( (use_alpha == true) && (use_beta == true) )
  144. {
  145. const eT val1 = alpha*acc1;
  146. C.at(k, i) = val1 + beta*C.at(k, i);
  147. if(i != k) { C.at(i, k) = val1 + beta*C.at(i, k); }
  148. }
  149. }
  150. }
  151. }
  152. };
  153. //! partial emulation of BLAS function syrk()
  154. template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
  155. class syrk_emul
  156. {
  157. public:
  158. template<typename eT, typename TA>
  159. arma_hot
  160. inline
  161. static
  162. void
  163. apply
  164. (
  165. Mat<eT>& C,
  166. const TA& A,
  167. const eT alpha = eT(1),
  168. const eT beta = eT(0)
  169. )
  170. {
  171. arma_extra_debug_sigprint();
  172. // do_trans_A == false -> C = alpha * A * A^T + beta*C
  173. // do_trans_A == true -> C = alpha * A^T * A + beta*C
  174. if(do_trans_A == false)
  175. {
  176. Mat<eT> AA;
  177. op_strans::apply_mat_noalias(AA, A);
  178. syrk_emul<true, use_alpha, use_beta>::apply(C, AA, alpha, beta);
  179. }
  180. else
  181. if(do_trans_A == true)
  182. {
  183. const uword A_n_rows = A.n_rows;
  184. const uword A_n_cols = A.n_cols;
  185. for(uword col_A=0; col_A < A_n_cols; ++col_A)
  186. {
  187. // col_A is interpreted as row_A when storing the results in matrix C
  188. const eT* A_coldata = A.colptr(col_A);
  189. for(uword k=col_A; k < A_n_cols; ++k)
  190. {
  191. const eT acc = op_dot::direct_dot_arma(A_n_rows, A_coldata, A.colptr(k));
  192. if( (use_alpha == false) && (use_beta == false) )
  193. {
  194. C.at(col_A, k) = acc;
  195. C.at(k, col_A) = acc;
  196. }
  197. else
  198. if( (use_alpha == true ) && (use_beta == false) )
  199. {
  200. const eT val = alpha*acc;
  201. C.at(col_A, k) = val;
  202. C.at(k, col_A) = val;
  203. }
  204. else
  205. if( (use_alpha == false) && (use_beta == true ) )
  206. {
  207. C.at(col_A, k) = acc + beta*C.at(col_A, k);
  208. if(col_A != k) { C.at(k, col_A) = acc + beta*C.at(k, col_A); }
  209. }
  210. else
  211. if( (use_alpha == true ) && (use_beta == true ) )
  212. {
  213. const eT val = alpha*acc;
  214. C.at(col_A, k) = val + beta*C.at(col_A, k);
  215. if(col_A != k) { C.at(k, col_A) = val + beta*C.at(k, col_A); }
  216. }
  217. }
  218. }
  219. }
  220. }
  221. };
  222. template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
  223. class syrk
  224. {
  225. public:
  226. template<typename eT, typename TA>
  227. inline
  228. static
  229. void
  230. apply_blas_type( Mat<eT>& C, const TA& A, const eT alpha = eT(1), const eT beta = eT(0) )
  231. {
  232. arma_extra_debug_sigprint();
  233. if(A.is_vec())
  234. {
  235. // work around poor handling of vectors by syrk() in ATLAS 3.8.4 and standard BLAS
  236. syrk_vec<do_trans_A, use_alpha, use_beta>::apply(C,A,alpha,beta);
  237. return;
  238. }
  239. const uword threshold = (is_cx<eT>::yes ? 16u : 48u);
  240. if( A.n_elem <= threshold )
  241. {
  242. syrk_emul<do_trans_A, use_alpha, use_beta>::apply(C,A,alpha,beta);
  243. }
  244. else
  245. {
  246. #if defined(ARMA_USE_ATLAS)
  247. {
  248. if(use_beta == true)
  249. {
  250. // use a temporary matrix, as we can't assume that matrix C is already symmetric
  251. Mat<eT> D(C.n_rows, C.n_cols);
  252. syrk<do_trans_A, use_alpha, false>::apply_blas_type(D,A,alpha);
  253. // NOTE: assuming beta=1; this is okay for now, as currently glue_times only uses beta=1
  254. arrayops::inplace_plus(C.memptr(), D.memptr(), C.n_elem);
  255. return;
  256. }
  257. atlas::cblas_syrk<eT>
  258. (
  259. atlas::CblasColMajor,
  260. atlas::CblasUpper,
  261. (do_trans_A) ? atlas::CblasTrans : atlas::CblasNoTrans,
  262. C.n_cols,
  263. (do_trans_A) ? A.n_rows : A.n_cols,
  264. (use_alpha) ? alpha : eT(1),
  265. A.mem,
  266. (do_trans_A) ? A.n_rows : C.n_cols,
  267. (use_beta) ? beta : eT(0),
  268. C.memptr(),
  269. C.n_cols
  270. );
  271. syrk_helper::inplace_copy_upper_tri_to_lower_tri(C);
  272. }
  273. #elif defined(ARMA_USE_BLAS)
  274. {
  275. if(use_beta == true)
  276. {
  277. // use a temporary matrix, as we can't assume that matrix C is already symmetric
  278. Mat<eT> D(C.n_rows, C.n_cols);
  279. syrk<do_trans_A, use_alpha, false>::apply_blas_type(D,A,alpha);
  280. // NOTE: assuming beta=1; this is okay for now, as currently glue_times only uses beta=1
  281. arrayops::inplace_plus(C.memptr(), D.memptr(), C.n_elem);
  282. return;
  283. }
  284. arma_extra_debug_print("blas::syrk()");
  285. const char uplo = 'U';
  286. const char trans_A = (do_trans_A) ? 'T' : 'N';
  287. const blas_int n = blas_int(C.n_cols);
  288. const blas_int k = (do_trans_A) ? blas_int(A.n_rows) : blas_int(A.n_cols);
  289. const eT local_alpha = (use_alpha) ? alpha : eT(1);
  290. const eT local_beta = (use_beta) ? beta : eT(0);
  291. const blas_int lda = (do_trans_A) ? k : n;
  292. arma_extra_debug_print( arma_str::format("blas::syrk(): trans_A = %c") % trans_A );
  293. blas::syrk<eT>
  294. (
  295. &uplo,
  296. &trans_A,
  297. &n,
  298. &k,
  299. &local_alpha,
  300. A.mem,
  301. &lda,
  302. &local_beta,
  303. C.memptr(),
  304. &n // &ldc
  305. );
  306. syrk_helper::inplace_copy_upper_tri_to_lower_tri(C);
  307. }
  308. #else
  309. {
  310. syrk_emul<do_trans_A, use_alpha, use_beta>::apply(C,A,alpha,beta);
  311. }
  312. #endif
  313. }
  314. }
  315. template<typename eT, typename TA>
  316. inline
  317. static
  318. void
  319. apply( Mat<eT>& C, const TA& A, const eT alpha = eT(1), const eT beta = eT(0) )
  320. {
  321. if(is_cx<eT>::no)
  322. {
  323. if(A.is_vec())
  324. {
  325. syrk_vec<do_trans_A, use_alpha, use_beta>::apply(C,A,alpha,beta);
  326. }
  327. else
  328. {
  329. syrk_emul<do_trans_A, use_alpha, use_beta>::apply(C,A,alpha,beta);
  330. }
  331. }
  332. else
  333. {
  334. // handling of complex matrix by syrk_emul() is not yet implemented
  335. return;
  336. }
  337. }
  338. template<typename TA>
  339. arma_inline
  340. static
  341. void
  342. apply
  343. (
  344. Mat<float>& C,
  345. const TA& A,
  346. const float alpha = float(1),
  347. const float beta = float(0)
  348. )
  349. {
  350. syrk<do_trans_A, use_alpha, use_beta>::apply_blas_type(C,A,alpha,beta);
  351. }
  352. template<typename TA>
  353. arma_inline
  354. static
  355. void
  356. apply
  357. (
  358. Mat<double>& C,
  359. const TA& A,
  360. const double alpha = double(1),
  361. const double beta = double(0)
  362. )
  363. {
  364. syrk<do_trans_A, use_alpha, use_beta>::apply_blas_type(C,A,alpha,beta);
  365. }
  366. template<typename TA>
  367. arma_inline
  368. static
  369. void
  370. apply
  371. (
  372. Mat< std::complex<float> >& C,
  373. const TA& A,
  374. const std::complex<float> alpha = std::complex<float>(1),
  375. const std::complex<float> beta = std::complex<float>(0)
  376. )
  377. {
  378. arma_ignore(C);
  379. arma_ignore(A);
  380. arma_ignore(alpha);
  381. arma_ignore(beta);
  382. // handling of complex matrix by syrk() is not yet implemented
  383. return;
  384. }
  385. template<typename TA>
  386. arma_inline
  387. static
  388. void
  389. apply
  390. (
  391. Mat< std::complex<double> >& C,
  392. const TA& A,
  393. const std::complex<double> alpha = std::complex<double>(1),
  394. const std::complex<double> beta = std::complex<double>(0)
  395. )
  396. {
  397. arma_ignore(C);
  398. arma_ignore(A);
  399. arma_ignore(alpha);
  400. arma_ignore(beta);
  401. // handling of complex matrix by syrk() is not yet implemented
  402. return;
  403. }
  404. };
  405. //! @}