mul_gemm.hpp 12 KB

  // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  // Copyright 2008-2016 National ICT Australia (NICTA)
  //
  // Licensed under the Apache License, Version 2.0 (the "License");
  // you may not use this file except in compliance with the License.
  // You may obtain a copy of the License at
  // http://www.apache.org/licenses/LICENSE-2.0
  //
  // Unless required by applicable law or agreed to in writing, software
  // distributed under the License is distributed on an "AS IS" BASIS,
  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  // See the License for the specific language governing permissions and
  // limitations under the License.
  // ------------------------------------------------------------------------
  //! \addtogroup gemm
  //! @{
  17. //! for tiny square matrices, size <= 4x4
  18. template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
  19. class gemm_emul_tinysq
  20. {
  21. public:
  22. template<typename eT, typename TA, typename TB>
  23. arma_cold
  24. inline
  25. static
  26. void
  27. apply
  28. (
  29. Mat<eT>& C,
  30. const TA& A,
  31. const TB& B,
  32. const eT alpha = eT(1),
  33. const eT beta = eT(0)
  34. )
  35. {
  36. arma_extra_debug_sigprint();
  37. switch(A.n_rows)
  38. {
  39. case 4: gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(3), A, B.colptr(3), alpha, beta );
  40. // fallthrough
  41. case 3: gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(2), A, B.colptr(2), alpha, beta );
  42. // fallthrough
  43. case 2: gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(1), A, B.colptr(1), alpha, beta );
  44. // fallthrough
  45. case 1: gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(0), A, B.colptr(0), alpha, beta );
  46. // fallthrough
  47. default: ;
  48. }
  49. }
  50. };
  51. //! emulation of gemm(), for non-complex matrices only, as it assumes only simple transposes (ie. doesn't do hermitian transposes)
  52. template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
  53. class gemm_emul_large
  54. {
  55. public:
  56. template<typename eT, typename TA, typename TB>
  57. arma_hot
  58. inline
  59. static
  60. void
  61. apply
  62. (
  63. Mat<eT>& C,
  64. const TA& A,
  65. const TB& B,
  66. const eT alpha = eT(1),
  67. const eT beta = eT(0)
  68. )
  69. {
  70. arma_extra_debug_sigprint();
  71. const uword A_n_rows = A.n_rows;
  72. const uword A_n_cols = A.n_cols;
  73. const uword B_n_rows = B.n_rows;
  74. const uword B_n_cols = B.n_cols;
  75. if( (do_trans_A == false) && (do_trans_B == false) )
  76. {
  77. arma_aligned podarray<eT> tmp(A_n_cols);
  78. eT* A_rowdata = tmp.memptr();
  79. for(uword row_A=0; row_A < A_n_rows; ++row_A)
  80. {
  81. tmp.copy_row(A, row_A);
  82. for(uword col_B=0; col_B < B_n_cols; ++col_B)
  83. {
  84. const eT acc = op_dot::direct_dot_arma(B_n_rows, A_rowdata, B.colptr(col_B));
  85. if( (use_alpha == false) && (use_beta == false) ) { C.at(row_A,col_B) = acc; }
  86. else if( (use_alpha == true ) && (use_beta == false) ) { C.at(row_A,col_B) = alpha*acc; }
  87. else if( (use_alpha == false) && (use_beta == true ) ) { C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B); }
  88. else if( (use_alpha == true ) && (use_beta == true ) ) { C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B); }
  89. }
  90. }
  91. }
  92. else
  93. if( (do_trans_A == true) && (do_trans_B == false) )
  94. {
  95. for(uword col_A=0; col_A < A_n_cols; ++col_A)
  96. {
  97. // col_A is interpreted as row_A when storing the results in matrix C
  98. const eT* A_coldata = A.colptr(col_A);
  99. for(uword col_B=0; col_B < B_n_cols; ++col_B)
  100. {
  101. const eT acc = op_dot::direct_dot_arma(B_n_rows, A_coldata, B.colptr(col_B));
  102. if( (use_alpha == false) && (use_beta == false) ) { C.at(col_A,col_B) = acc; }
  103. else if( (use_alpha == true ) && (use_beta == false) ) { C.at(col_A,col_B) = alpha*acc; }
  104. else if( (use_alpha == false) && (use_beta == true ) ) { C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B); }
  105. else if( (use_alpha == true ) && (use_beta == true ) ) { C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B); }
  106. }
  107. }
  108. }
  109. else
  110. if( (do_trans_A == false) && (do_trans_B == true) )
  111. {
  112. Mat<eT> BB;
  113. op_strans::apply_mat_noalias(BB, B);
  114. gemm_emul_large<false, false, use_alpha, use_beta>::apply(C, A, BB, alpha, beta);
  115. }
  116. else
  117. if( (do_trans_A == true) && (do_trans_B == true) )
  118. {
  119. // mat B_tmp = trans(B);
  120. // dgemm_arma<true, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
  121. // By using the trans(A)*trans(B) = trans(B*A) equivalency,
  122. // transpose operations are not needed
  123. arma_aligned podarray<eT> tmp(B.n_cols);
  124. eT* B_rowdata = tmp.memptr();
  125. for(uword row_B=0; row_B < B_n_rows; ++row_B)
  126. {
  127. tmp.copy_row(B, row_B);
  128. for(uword col_A=0; col_A < A_n_cols; ++col_A)
  129. {
  130. const eT acc = op_dot::direct_dot_arma(A_n_rows, B_rowdata, A.colptr(col_A));
  131. if( (use_alpha == false) && (use_beta == false) ) { C.at(col_A,row_B) = acc; }
  132. else if( (use_alpha == true ) && (use_beta == false) ) { C.at(col_A,row_B) = alpha*acc; }
  133. else if( (use_alpha == false) && (use_beta == true ) ) { C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B); }
  134. else if( (use_alpha == true ) && (use_beta == true ) ) { C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B); }
  135. }
  136. }
  137. }
  138. }
  139. };
  140. template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
  141. class gemm_emul
  142. {
  143. public:
  144. template<typename eT, typename TA, typename TB>
  145. arma_hot
  146. inline
  147. static
  148. void
  149. apply
  150. (
  151. Mat<eT>& C,
  152. const TA& A,
  153. const TB& B,
  154. const eT alpha = eT(1),
  155. const eT beta = eT(0),
  156. const typename arma_not_cx<eT>::result* junk = 0
  157. )
  158. {
  159. arma_extra_debug_sigprint();
  160. arma_ignore(junk);
  161. gemm_emul_large<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C, A, B, alpha, beta);
  162. }
  163. template<typename eT>
  164. arma_hot
  165. inline
  166. static
  167. void
  168. apply
  169. (
  170. Mat<eT>& C,
  171. const Mat<eT>& A,
  172. const Mat<eT>& B,
  173. const eT alpha = eT(1),
  174. const eT beta = eT(0),
  175. const typename arma_cx_only<eT>::result* junk = 0
  176. )
  177. {
  178. arma_extra_debug_sigprint();
  179. arma_ignore(junk);
  180. // "better than nothing" handling of hermitian transposes for complex number matrices
  181. Mat<eT> tmp_A;
  182. Mat<eT> tmp_B;
  183. if(do_trans_A) { op_htrans::apply_mat_noalias(tmp_A, A); }
  184. if(do_trans_B) { op_htrans::apply_mat_noalias(tmp_B, B); }
  185. const Mat<eT>& AA = (do_trans_A == false) ? A : tmp_A;
  186. const Mat<eT>& BB = (do_trans_B == false) ? B : tmp_B;
  187. gemm_emul_large<false, false, use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
  188. }
  189. };
  190. //! \brief
  191. //! Wrapper for ATLAS/BLAS dgemm function, using template arguments to control the arguments passed to dgemm.
  192. //! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes)
  193. template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
  194. class gemm
  195. {
  196. public:
  197. template<typename eT, typename TA, typename TB>
  198. inline
  199. static
  200. void
  201. apply_blas_type( Mat<eT>& C, const TA& A, const TB& B, const eT alpha = eT(1), const eT beta = eT(0) )
  202. {
  203. arma_extra_debug_sigprint();
  204. if( (A.n_rows <= 4) && (A.n_rows == A.n_cols) && (A.n_rows == B.n_rows) && (B.n_rows == B.n_cols) && (is_cx<eT>::no) )
  205. {
  206. if(do_trans_B == false)
  207. {
  208. gemm_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(C, A, B, alpha, beta);
  209. }
  210. else
  211. {
  212. Mat<eT> BB(B.n_rows, B.n_rows);
  213. op_strans::apply_mat_noalias_tinysq(BB, B);
  214. gemm_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(C, A, BB, alpha, beta);
  215. }
  216. }
  217. else
  218. {
  219. #if defined(ARMA_USE_ATLAS)
  220. {
  221. arma_extra_debug_print("atlas::cblas_gemm()");
  222. arma_debug_assert_atlas_size(A,B);
  223. atlas::cblas_gemm<eT>
  224. (
  225. atlas::CblasColMajor,
  226. (do_trans_A) ? ( is_cx<eT>::yes ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans,
  227. (do_trans_B) ? ( is_cx<eT>::yes ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans,
  228. C.n_rows,
  229. C.n_cols,
  230. (do_trans_A) ? A.n_rows : A.n_cols,
  231. (use_alpha) ? alpha : eT(1),
  232. A.mem,
  233. (do_trans_A) ? A.n_rows : C.n_rows,
  234. B.mem,
  235. (do_trans_B) ? C.n_cols : ( (do_trans_A) ? A.n_rows : A.n_cols ),
  236. (use_beta) ? beta : eT(0),
  237. C.memptr(),
  238. C.n_rows
  239. );
  240. }
  241. #elif defined(ARMA_USE_BLAS)
  242. {
  243. arma_extra_debug_print("blas::gemm()");
  244. arma_debug_assert_blas_size(A,B);
  245. const char trans_A = (do_trans_A) ? ( is_cx<eT>::yes ? 'C' : 'T' ) : 'N';
  246. const char trans_B = (do_trans_B) ? ( is_cx<eT>::yes ? 'C' : 'T' ) : 'N';
  247. const blas_int m = blas_int(C.n_rows);
  248. const blas_int n = blas_int(C.n_cols);
  249. const blas_int k = (do_trans_A) ? blas_int(A.n_rows) : blas_int(A.n_cols);
  250. const eT local_alpha = (use_alpha) ? alpha : eT(1);
  251. const blas_int lda = (do_trans_A) ? k : m;
  252. const blas_int ldb = (do_trans_B) ? n : k;
  253. const eT local_beta = (use_beta) ? beta : eT(0);
  254. arma_extra_debug_print( arma_str::format("blas::gemm(): trans_A = %c") % trans_A );
  255. arma_extra_debug_print( arma_str::format("blas::gemm(): trans_B = %c") % trans_B );
  256. blas::gemm<eT>
  257. (
  258. &trans_A,
  259. &trans_B,
  260. &m,
  261. &n,
  262. &k,
  263. &local_alpha,
  264. A.mem,
  265. &lda,
  266. B.mem,
  267. &ldb,
  268. &local_beta,
  269. C.memptr(),
  270. &m
  271. );
  272. }
  273. #else
  274. {
  275. gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
  276. }
  277. #endif
  278. }
  279. }
  280. //! immediate multiplication of matrices A and B, storing the result in C
  281. template<typename eT, typename TA, typename TB>
  282. inline
  283. static
  284. void
  285. apply( Mat<eT>& C, const TA& A, const TB& B, const eT alpha = eT(1), const eT beta = eT(0) )
  286. {
  287. gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
  288. }
  289. template<typename TA, typename TB>
  290. arma_inline
  291. static
  292. void
  293. apply
  294. (
  295. Mat<float>& C,
  296. const TA& A,
  297. const TB& B,
  298. const float alpha = float(1),
  299. const float beta = float(0)
  300. )
  301. {
  302. gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
  303. }
  304. template<typename TA, typename TB>
  305. arma_inline
  306. static
  307. void
  308. apply
  309. (
  310. Mat<double>& C,
  311. const TA& A,
  312. const TB& B,
  313. const double alpha = double(1),
  314. const double beta = double(0)
  315. )
  316. {
  317. gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
  318. }
  319. template<typename TA, typename TB>
  320. arma_inline
  321. static
  322. void
  323. apply
  324. (
  325. Mat< std::complex<float> >& C,
  326. const TA& A,
  327. const TB& B,
  328. const std::complex<float> alpha = std::complex<float>(1),
  329. const std::complex<float> beta = std::complex<float>(0)
  330. )
  331. {
  332. gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
  333. }
  334. template<typename TA, typename TB>
  335. arma_inline
  336. static
  337. void
  338. apply
  339. (
  340. Mat< std::complex<double> >& C,
  341. const TA& A,
  342. const TB& B,
  343. const std::complex<double> alpha = std::complex<double>(1),
  344. const std::complex<double> beta = std::complex<double>(0)
  345. )
  346. {
  347. gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
  348. }
  349. };
  350. //! @}