// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
// Copyright 2008-2016 National ICT Australia (NICTA)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ------------------------------------------------------------------------
//! \addtogroup gemm_mixed
//! @{



//! \brief
//! Matrix multiplication where the matrices have differing element types.
//! Uses caching for speedup.
//! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes)
  21. template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
  22. class gemm_mixed_large
  23. {
  24. public:
  25. template<typename out_eT, typename in_eT1, typename in_eT2>
  26. arma_hot
  27. inline
  28. static
  29. void
  30. apply
  31. (
  32. Mat<out_eT>& C,
  33. const Mat<in_eT1>& A,
  34. const Mat<in_eT2>& B,
  35. const out_eT alpha = out_eT(1),
  36. const out_eT beta = out_eT(0)
  37. )
  38. {
  39. arma_extra_debug_sigprint();
  40. const uword A_n_rows = A.n_rows;
  41. const uword A_n_cols = A.n_cols;
  42. const uword B_n_rows = B.n_rows;
  43. const uword B_n_cols = B.n_cols;
  44. if( (do_trans_A == false) && (do_trans_B == false) )
  45. {
  46. podarray<in_eT1> tmp(A_n_cols);
  47. in_eT1* A_rowdata = tmp.memptr();
  48. #if defined(ARMA_USE_OPENMP)
  49. const bool use_mp = (B_n_cols >= 2) && (B.n_elem >= 8192) && (mp_thread_limit::in_parallel() == false);
  50. #else
  51. const bool use_mp = false;
  52. #endif
  53. if(use_mp)
  54. {
  55. #if defined(ARMA_USE_OPENMP)
  56. {
  57. const int n_threads = int( (std::min)( uword(mp_thread_limit::get()), uword(B_n_cols) ) );
  58. for(uword row_A=0; row_A < A_n_rows; ++row_A)
  59. {
  60. tmp.copy_row(A, row_A);
  61. #pragma omp parallel for schedule(static) num_threads(n_threads)
  62. for(uword col_B=0; col_B < B_n_cols; ++col_B)
  63. {
  64. const in_eT2* B_coldata = B.colptr(col_B);
  65. out_eT acc = out_eT(0);
  66. for(uword i=0; i < B_n_rows; ++i)
  67. {
  68. acc += upgrade_val<in_eT1,in_eT2>::apply(A_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
  69. }
  70. if( (use_alpha == false) && (use_beta == false) ) { C.at(row_A,col_B) = acc; }
  71. else if( (use_alpha == true ) && (use_beta == false) ) { C.at(row_A,col_B) = alpha*acc; }
  72. else if( (use_alpha == false) && (use_beta == true ) ) { C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B); }
  73. else if( (use_alpha == true ) && (use_beta == true ) ) { C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B); }
  74. }
  75. }
  76. }
  77. #endif
  78. }
  79. else
  80. {
  81. for(uword row_A=0; row_A < A_n_rows; ++row_A)
  82. {
  83. tmp.copy_row(A, row_A);
  84. for(uword col_B=0; col_B < B_n_cols; ++col_B)
  85. {
  86. const in_eT2* B_coldata = B.colptr(col_B);
  87. out_eT acc = out_eT(0);
  88. for(uword i=0; i < B_n_rows; ++i)
  89. {
  90. acc += upgrade_val<in_eT1,in_eT2>::apply(A_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
  91. }
  92. if( (use_alpha == false) && (use_beta == false) ) { C.at(row_A,col_B) = acc; }
  93. else if( (use_alpha == true ) && (use_beta == false) ) { C.at(row_A,col_B) = alpha*acc; }
  94. else if( (use_alpha == false) && (use_beta == true ) ) { C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B); }
  95. else if( (use_alpha == true ) && (use_beta == true ) ) { C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B); }
  96. }
  97. }
  98. }
  99. }
  100. else
  101. if( (do_trans_A == true) && (do_trans_B == false) )
  102. {
  103. #if defined(ARMA_USE_OPENMP)
  104. const bool use_mp = (B_n_cols >= 2) && (B.n_elem >= 8192) && (mp_thread_limit::in_parallel() == false);
  105. #else
  106. const bool use_mp = false;
  107. #endif
  108. if(use_mp)
  109. {
  110. #if defined(ARMA_USE_OPENMP)
  111. {
  112. const int n_threads = int( (std::min)( uword(mp_thread_limit::get()), uword(B_n_cols) ) );
  113. for(uword col_A=0; col_A < A_n_cols; ++col_A)
  114. {
  115. // col_A is interpreted as row_A when storing the results in matrix C
  116. const in_eT1* A_coldata = A.colptr(col_A);
  117. #pragma omp parallel for schedule(static) num_threads(n_threads)
  118. for(uword col_B=0; col_B < B_n_cols; ++col_B)
  119. {
  120. const in_eT2* B_coldata = B.colptr(col_B);
  121. out_eT acc = out_eT(0);
  122. for(uword i=0; i < B_n_rows; ++i)
  123. {
  124. acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
  125. }
  126. if( (use_alpha == false) && (use_beta == false) ) { C.at(col_A,col_B) = acc; }
  127. else if( (use_alpha == true ) && (use_beta == false) ) { C.at(col_A,col_B) = alpha*acc; }
  128. else if( (use_alpha == false) && (use_beta == true ) ) { C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B); }
  129. else if( (use_alpha == true ) && (use_beta == true ) ) { C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B); }
  130. }
  131. }
  132. }
  133. #endif
  134. }
  135. else
  136. {
  137. for(uword col_A=0; col_A < A_n_cols; ++col_A)
  138. {
  139. // col_A is interpreted as row_A when storing the results in matrix C
  140. const in_eT1* A_coldata = A.colptr(col_A);
  141. for(uword col_B=0; col_B < B_n_cols; ++col_B)
  142. {
  143. const in_eT2* B_coldata = B.colptr(col_B);
  144. out_eT acc = out_eT(0);
  145. for(uword i=0; i < B_n_rows; ++i)
  146. {
  147. acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
  148. }
  149. if( (use_alpha == false) && (use_beta == false) ) { C.at(col_A,col_B) = acc; }
  150. else if( (use_alpha == true ) && (use_beta == false) ) { C.at(col_A,col_B) = alpha*acc; }
  151. else if( (use_alpha == false) && (use_beta == true ) ) { C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B); }
  152. else if( (use_alpha == true ) && (use_beta == true ) ) { C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B); }
  153. }
  154. }
  155. }
  156. }
  157. else
  158. if( (do_trans_A == false) && (do_trans_B == true) )
  159. {
  160. Mat<in_eT2> B_tmp;
  161. op_strans::apply_mat_noalias(B_tmp, B);
  162. gemm_mixed_large<false, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
  163. }
  164. else
  165. if( (do_trans_A == true) && (do_trans_B == true) )
  166. {
  167. // mat B_tmp = trans(B);
  168. // dgemm_arma<true, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
  169. // By using the trans(A)*trans(B) = trans(B*A) equivalency,
  170. // transpose operations are not needed
  171. podarray<in_eT2> tmp(B_n_cols);
  172. in_eT2* B_rowdata = tmp.memptr();
  173. for(uword row_B=0; row_B < B_n_rows; ++row_B)
  174. {
  175. tmp.copy_row(B, row_B);
  176. for(uword col_A=0; col_A < A_n_cols; ++col_A)
  177. {
  178. const in_eT1* A_coldata = A.colptr(col_A);
  179. out_eT acc = out_eT(0);
  180. for(uword i=0; i < A_n_rows; ++i)
  181. {
  182. acc += upgrade_val<in_eT1,in_eT2>::apply(B_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]);
  183. }
  184. if( (use_alpha == false) && (use_beta == false) ) { C.at(col_A,row_B) = acc; }
  185. else if( (use_alpha == true ) && (use_beta == false) ) { C.at(col_A,row_B) = alpha*acc; }
  186. else if( (use_alpha == false) && (use_beta == true ) ) { C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B); }
  187. else if( (use_alpha == true ) && (use_beta == true ) ) { C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B); }
  188. }
  189. }
  190. }
  191. }
  192. };
//! \brief
//! Matrix multiplication where the matrices have differing element types.
  195. template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
  196. class gemm_mixed
  197. {
  198. public:
  199. //! immediate multiplication of matrices A and B, storing the result in C
  200. template<typename out_eT, typename in_eT1, typename in_eT2>
  201. inline
  202. static
  203. void
  204. apply
  205. (
  206. Mat<out_eT>& C,
  207. const Mat<in_eT1>& A,
  208. const Mat<in_eT2>& B,
  209. const out_eT alpha = out_eT(1),
  210. const out_eT beta = out_eT(0)
  211. )
  212. {
  213. arma_extra_debug_sigprint();
  214. if((is_cx<in_eT1>::yes && do_trans_A) || (is_cx<in_eT2>::yes && do_trans_B))
  215. {
  216. // better-than-nothing handling of hermitian transpose
  217. Mat<in_eT1> tmp_A;
  218. Mat<in_eT2> tmp_B;
  219. const bool predo_trans_A = ( (do_trans_A == true) && (is_cx<in_eT1>::yes) );
  220. const bool predo_trans_B = ( (do_trans_B == true) && (is_cx<in_eT2>::yes) );
  221. if(predo_trans_A) { op_htrans::apply_mat_noalias(tmp_A, A); }
  222. if(predo_trans_B) { op_htrans::apply_mat_noalias(tmp_B, B); }
  223. const Mat<in_eT1>& AA = (predo_trans_A == false) ? A : tmp_A;
  224. const Mat<in_eT2>& BB = (predo_trans_B == false) ? B : tmp_B;
  225. gemm_mixed_large<((predo_trans_A) ? false : do_trans_A), ((predo_trans_B) ? false : do_trans_B), use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
  226. }
  227. else
  228. {
  229. gemm_mixed_large<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C, A, B, alpha, beta);
  230. }
  231. }
  232. };
//! @}