glue_times_meat.hpp 34 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup glue_times
  16. //! @{
  17. template<bool do_inv_detect>
  18. template<typename T1, typename T2>
  19. arma_hot
  20. inline
  21. void
  22. glue_times_redirect2_helper<do_inv_detect>::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
  23. {
  24. arma_extra_debug_sigprint();
  25. typedef typename T1::elem_type eT;
  26. const partial_unwrap<T1> tmp1(X.A);
  27. const partial_unwrap<T2> tmp2(X.B);
  28. const typename partial_unwrap<T1>::stored_type& A = tmp1.M;
  29. const typename partial_unwrap<T2>::stored_type& B = tmp2.M;
  30. const bool use_alpha = partial_unwrap<T1>::do_times || partial_unwrap<T2>::do_times;
  31. const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val()) : eT(0);
  32. const bool alias = tmp1.is_alias(out) || tmp2.is_alias(out);
  33. if(alias == false)
  34. {
  35. glue_times::apply
  36. <
  37. eT,
  38. partial_unwrap<T1>::do_trans,
  39. partial_unwrap<T2>::do_trans,
  40. (partial_unwrap<T1>::do_times || partial_unwrap<T2>::do_times)
  41. >
  42. (out, A, B, alpha);
  43. }
  44. else
  45. {
  46. Mat<eT> tmp;
  47. glue_times::apply
  48. <
  49. eT,
  50. partial_unwrap<T1>::do_trans,
  51. partial_unwrap<T2>::do_trans,
  52. (partial_unwrap<T1>::do_times || partial_unwrap<T2>::do_times)
  53. >
  54. (tmp, A, B, alpha);
  55. out.steal_mem(tmp);
  56. }
  57. }
  58. template<typename T1, typename T2>
  59. arma_hot
  60. inline
  61. void
  62. glue_times_redirect2_helper<true>::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
  63. {
  64. arma_extra_debug_sigprint();
  65. typedef typename T1::elem_type eT;
  66. if(strip_inv<T1>::do_inv == true)
  67. {
  68. // replace inv(A)*B with solve(A,B)
  69. arma_extra_debug_print("glue_times_redirect<2>::apply(): detected inv(A)*B");
  70. const strip_inv<T1> A_strip(X.A);
  71. Mat<eT> A = A_strip.M;
  72. arma_debug_check( (A.is_square() == false), "inv(): given matrix must be square sized" );
  73. if(strip_inv<T1>::do_inv_sympd)
  74. {
  75. // if(auxlib::rudimentary_sym_check(A) == false)
  76. // {
  77. // if(is_cx<eT>::no ) { arma_debug_warn("inv_sympd(): given matrix is not symmetric"); }
  78. // if(is_cx<eT>::yes) { arma_debug_warn("inv_sympd(): given matrix is not hermitian"); }
  79. //
  80. // out.soft_reset();
  81. // arma_stop_runtime_error("matrix multiplication: problem with matrix inverse; suggest to use solve() instead");
  82. //
  83. // return;
  84. // }
  85. if( (arma_config::debug) && (auxlib::rudimentary_sym_check(A) == false) )
  86. {
  87. if(is_cx<eT>::no ) { arma_debug_warn("inv_sympd(): given matrix is not symmetric"); }
  88. if(is_cx<eT>::yes) { arma_debug_warn("inv_sympd(): given matrix is not hermitian"); }
  89. }
  90. }
  91. const unwrap_check<T2> B_tmp(X.B, out);
  92. const Mat<eT>& B = B_tmp.M;
  93. arma_debug_assert_mul_size(A, B, "matrix multiplication");
  94. // TODO: detect sympd via sympd_helper::guess_sympd(A) ?
  95. #if defined(ARMA_OPTIMISE_SYMPD)
  96. const bool status = (strip_inv<T1>::do_inv_sympd) ? auxlib::solve_sympd_fast(out, A, B) : auxlib::solve_square_fast(out, A, B);
  97. #else
  98. const bool status = auxlib::solve_square_fast(out, A, B);
  99. #endif
  100. if(status == false)
  101. {
  102. out.soft_reset();
  103. arma_stop_runtime_error("matrix multiplication: problem with matrix inverse; suggest to use solve() instead");
  104. }
  105. return;
  106. }
  107. #if defined(ARMA_OPTIMISE_SYMPD)
  108. {
  109. if(strip_inv<T2>::do_inv_sympd)
  110. {
  111. // replace A*inv_sympd(B) with trans( solve(trans(B),trans(A)) )
  112. // transpose of B is avoided as B is explicitly marked as symmetric
  113. arma_extra_debug_print("glue_times_redirect<2>::apply(): detected A*inv_sympd(B)");
  114. const Mat<eT> At = trans(X.A);
  115. const strip_inv<T2> B_strip(X.B);
  116. Mat<eT> B = B_strip.M;
  117. arma_debug_check( (B.is_square() == false), "inv_sympd(): given matrix must be square sized" );
  118. // if(auxlib::rudimentary_sym_check(B) == false)
  119. // {
  120. // if(is_cx<eT>::no ) { arma_debug_warn("inv_sympd(): given matrix is not symmetric"); }
  121. // if(is_cx<eT>::yes) { arma_debug_warn("inv_sympd(): given matrix is not hermitian"); }
  122. //
  123. // out.soft_reset();
  124. // arma_stop_runtime_error("matrix multiplication: problem with matrix inverse; suggest to use solve() instead");
  125. //
  126. // return;
  127. // }
  128. if( (arma_config::debug) && (auxlib::rudimentary_sym_check(B) == false) )
  129. {
  130. if(is_cx<eT>::no ) { arma_debug_warn("inv_sympd(): given matrix is not symmetric"); }
  131. if(is_cx<eT>::yes) { arma_debug_warn("inv_sympd(): given matrix is not hermitian"); }
  132. }
  133. arma_debug_assert_mul_size(At.n_cols, At.n_rows, B.n_rows, B.n_cols, "matrix multiplication");
  134. const bool status = auxlib::solve_sympd_fast(out, B, At);
  135. if(status == false)
  136. {
  137. out.soft_reset();
  138. arma_stop_runtime_error("matrix multiplication: problem with matrix inverse; suggest to use solve() instead");
  139. }
  140. out = trans(out);
  141. return;
  142. }
  143. }
  144. #endif
  145. glue_times_redirect2_helper<false>::apply(out, X);
  146. }
  147. template<bool do_inv_detect>
  148. template<typename T1, typename T2, typename T3>
  149. arma_hot
  150. inline
  151. void
  152. glue_times_redirect3_helper<do_inv_detect>::apply(Mat<typename T1::elem_type>& out, const Glue< Glue<T1,T2,glue_times>, T3, glue_times>& X)
  153. {
  154. arma_extra_debug_sigprint();
  155. typedef typename T1::elem_type eT;
  156. // we have exactly 3 objects
  157. // hence we can safely expand X as X.A.A, X.A.B and X.B
  158. const partial_unwrap<T1> tmp1(X.A.A);
  159. const partial_unwrap<T2> tmp2(X.A.B);
  160. const partial_unwrap<T3> tmp3(X.B );
  161. const typename partial_unwrap<T1>::stored_type& A = tmp1.M;
  162. const typename partial_unwrap<T2>::stored_type& B = tmp2.M;
  163. const typename partial_unwrap<T3>::stored_type& C = tmp3.M;
  164. const bool use_alpha = partial_unwrap<T1>::do_times || partial_unwrap<T2>::do_times || partial_unwrap<T3>::do_times;
  165. const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val() * tmp3.get_val()) : eT(0);
  166. const bool alias = tmp1.is_alias(out) || tmp2.is_alias(out) || tmp3.is_alias(out);
  167. if(alias == false)
  168. {
  169. glue_times::apply
  170. <
  171. eT,
  172. partial_unwrap<T1>::do_trans,
  173. partial_unwrap<T2>::do_trans,
  174. partial_unwrap<T3>::do_trans,
  175. (partial_unwrap<T1>::do_times || partial_unwrap<T2>::do_times || partial_unwrap<T3>::do_times)
  176. >
  177. (out, A, B, C, alpha);
  178. }
  179. else
  180. {
  181. Mat<eT> tmp;
  182. glue_times::apply
  183. <
  184. eT,
  185. partial_unwrap<T1>::do_trans,
  186. partial_unwrap<T2>::do_trans,
  187. partial_unwrap<T3>::do_trans,
  188. (partial_unwrap<T1>::do_times || partial_unwrap<T2>::do_times || partial_unwrap<T3>::do_times)
  189. >
  190. (tmp, A, B, C, alpha);
  191. out.steal_mem(tmp);
  192. }
  193. }
  194. template<typename T1, typename T2, typename T3>
  195. arma_hot
  196. inline
  197. void
  198. glue_times_redirect3_helper<true>::apply(Mat<typename T1::elem_type>& out, const Glue< Glue<T1,T2,glue_times>, T3, glue_times>& X)
  199. {
  200. arma_extra_debug_sigprint();
  201. typedef typename T1::elem_type eT;
  202. if(strip_inv<T1>::do_inv == true)
  203. {
  204. // replace inv(A)*B*C with solve(A,B*C);
  205. arma_extra_debug_print("glue_times_redirect<3>::apply(): detected inv(A)*B*C");
  206. const strip_inv<T1> A_strip(X.A.A);
  207. Mat<eT> A = A_strip.M;
  208. arma_debug_check( (A.is_square() == false), "inv(): given matrix must be square sized" );
  209. const partial_unwrap<T2> tmp2(X.A.B);
  210. const partial_unwrap<T3> tmp3(X.B );
  211. const typename partial_unwrap<T2>::stored_type& B = tmp2.M;
  212. const typename partial_unwrap<T3>::stored_type& C = tmp3.M;
  213. const bool use_alpha = partial_unwrap<T2>::do_times || partial_unwrap<T3>::do_times;
  214. const eT alpha = use_alpha ? (tmp2.get_val() * tmp3.get_val()) : eT(0);
  215. Mat<eT> BC;
  216. glue_times::apply
  217. <
  218. eT,
  219. partial_unwrap<T2>::do_trans,
  220. partial_unwrap<T3>::do_trans,
  221. (partial_unwrap<T2>::do_times || partial_unwrap<T3>::do_times)
  222. >
  223. (BC, B, C, alpha);
  224. arma_debug_assert_mul_size(A, BC, "matrix multiplication");
  225. // TODO: detect sympd via sympd_helper::guess_sympd(A) ?
  226. #if defined(ARMA_OPTIMISE_SYMPD)
  227. const bool status = (strip_inv<T1>::do_inv_sympd) ? auxlib::solve_sympd_fast(out, A, BC) : auxlib::solve_square_fast(out, A, BC);
  228. #else
  229. const bool status = auxlib::solve_square_fast(out, A, BC);
  230. #endif
  231. if(status == false)
  232. {
  233. out.soft_reset();
  234. arma_stop_runtime_error("matrix multiplication: problem with matrix inverse; suggest to use solve() instead");
  235. }
  236. return;
  237. }
  238. if(strip_inv<T2>::do_inv == true)
  239. {
  240. // replace A*inv(B)*C with A*solve(B,C)
  241. arma_extra_debug_print("glue_times_redirect<3>::apply(): detected A*inv(B)*C");
  242. const strip_inv<T2> B_strip(X.A.B);
  243. Mat<eT> B = B_strip.M;
  244. arma_debug_check( (B.is_square() == false), "inv(): given matrix must be square sized" );
  245. const unwrap<T3> C_tmp(X.B);
  246. const Mat<eT>& C = C_tmp.M;
  247. arma_debug_assert_mul_size(B, C, "matrix multiplication");
  248. Mat<eT> solve_result;
  249. #if defined(ARMA_OPTIMISE_SYMPD)
  250. const bool status = (strip_inv<T2>::do_inv_sympd) ? auxlib::solve_sympd_fast(solve_result, B, C) : auxlib::solve_square_fast(solve_result, B, C);
  251. #else
  252. const bool status = auxlib::solve_square_fast(solve_result, B, C);
  253. #endif
  254. if(status == false)
  255. {
  256. out.soft_reset();
  257. arma_stop_runtime_error("matrix multiplication: problem with matrix inverse; suggest to use solve() instead");
  258. return;
  259. }
  260. const partial_unwrap_check<T1> tmp1(X.A.A, out);
  261. const typename partial_unwrap_check<T1>::stored_type& A = tmp1.M;
  262. const bool use_alpha = partial_unwrap_check<T1>::do_times;
  263. const eT alpha = use_alpha ? tmp1.get_val() : eT(0);
  264. glue_times::apply
  265. <
  266. eT,
  267. partial_unwrap_check<T1>::do_trans,
  268. false,
  269. partial_unwrap_check<T1>::do_times
  270. >
  271. (out, A, solve_result, alpha);
  272. return;
  273. }
  274. glue_times_redirect3_helper<false>::apply(out, X);
  275. }
  276. template<uword N>
  277. template<typename T1, typename T2>
  278. arma_hot
  279. inline
  280. void
  281. glue_times_redirect<N>::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
  282. {
  283. arma_extra_debug_sigprint();
  284. typedef typename T1::elem_type eT;
  285. const partial_unwrap<T1> tmp1(X.A);
  286. const partial_unwrap<T2> tmp2(X.B);
  287. const typename partial_unwrap<T1>::stored_type& A = tmp1.M;
  288. const typename partial_unwrap<T2>::stored_type& B = tmp2.M;
  289. const bool use_alpha = partial_unwrap<T1>::do_times || partial_unwrap<T2>::do_times;
  290. const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val()) : eT(0);
  291. const bool alias = tmp1.is_alias(out) || tmp2.is_alias(out);
  292. if(alias == false)
  293. {
  294. glue_times::apply
  295. <
  296. eT,
  297. partial_unwrap<T1>::do_trans,
  298. partial_unwrap<T2>::do_trans,
  299. (partial_unwrap<T1>::do_times || partial_unwrap<T2>::do_times)
  300. >
  301. (out, A, B, alpha);
  302. }
  303. else
  304. {
  305. Mat<eT> tmp;
  306. glue_times::apply
  307. <
  308. eT,
  309. partial_unwrap<T1>::do_trans,
  310. partial_unwrap<T2>::do_trans,
  311. (partial_unwrap<T1>::do_times || partial_unwrap<T2>::do_times)
  312. >
  313. (tmp, A, B, alpha);
  314. out.steal_mem(tmp);
  315. }
  316. }
  317. template<typename T1, typename T2>
  318. arma_hot
  319. inline
  320. void
  321. glue_times_redirect<2>::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
  322. {
  323. arma_extra_debug_sigprint();
  324. typedef typename T1::elem_type eT;
  325. glue_times_redirect2_helper< is_supported_blas_type<eT>::value >::apply(out, X);
  326. }
  327. template<typename T1, typename T2, typename T3>
  328. arma_hot
  329. inline
  330. void
  331. glue_times_redirect<3>::apply(Mat<typename T1::elem_type>& out, const Glue< Glue<T1,T2,glue_times>, T3, glue_times>& X)
  332. {
  333. arma_extra_debug_sigprint();
  334. typedef typename T1::elem_type eT;
  335. glue_times_redirect3_helper< is_supported_blas_type<eT>::value >::apply(out, X);
  336. }
  337. template<typename T1, typename T2, typename T3, typename T4>
  338. arma_hot
  339. inline
  340. void
  341. glue_times_redirect<4>::apply(Mat<typename T1::elem_type>& out, const Glue< Glue< Glue<T1,T2,glue_times>, T3, glue_times>, T4, glue_times>& X)
  342. {
  343. arma_extra_debug_sigprint();
  344. typedef typename T1::elem_type eT;
  345. // there is exactly 4 objects
  346. // hence we can safely expand X as X.A.A.A, X.A.A.B, X.A.B and X.B
  347. const partial_unwrap<T1> tmp1(X.A.A.A);
  348. const partial_unwrap<T2> tmp2(X.A.A.B);
  349. const partial_unwrap<T3> tmp3(X.A.B );
  350. const partial_unwrap<T4> tmp4(X.B );
  351. const typename partial_unwrap<T1>::stored_type& A = tmp1.M;
  352. const typename partial_unwrap<T2>::stored_type& B = tmp2.M;
  353. const typename partial_unwrap<T3>::stored_type& C = tmp3.M;
  354. const typename partial_unwrap<T4>::stored_type& D = tmp4.M;
  355. const bool use_alpha = partial_unwrap<T1>::do_times || partial_unwrap<T2>::do_times || partial_unwrap<T3>::do_times || partial_unwrap<T4>::do_times;
  356. const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val() * tmp3.get_val() * tmp4.get_val()) : eT(0);
  357. const bool alias = tmp1.is_alias(out) || tmp2.is_alias(out) || tmp3.is_alias(out) || tmp4.is_alias(out);
  358. if(alias == false)
  359. {
  360. glue_times::apply
  361. <
  362. eT,
  363. partial_unwrap<T1>::do_trans,
  364. partial_unwrap<T2>::do_trans,
  365. partial_unwrap<T3>::do_trans,
  366. partial_unwrap<T4>::do_trans,
  367. (partial_unwrap<T1>::do_times || partial_unwrap<T2>::do_times || partial_unwrap<T3>::do_times || partial_unwrap<T4>::do_times)
  368. >
  369. (out, A, B, C, D, alpha);
  370. }
  371. else
  372. {
  373. Mat<eT> tmp;
  374. glue_times::apply
  375. <
  376. eT,
  377. partial_unwrap<T1>::do_trans,
  378. partial_unwrap<T2>::do_trans,
  379. partial_unwrap<T3>::do_trans,
  380. partial_unwrap<T4>::do_trans,
  381. (partial_unwrap<T1>::do_times || partial_unwrap<T2>::do_times || partial_unwrap<T3>::do_times || partial_unwrap<T4>::do_times)
  382. >
  383. (tmp, A, B, C, D, alpha);
  384. out.steal_mem(tmp);
  385. }
  386. }
  387. template<typename T1, typename T2>
  388. arma_hot
  389. inline
  390. void
  391. glue_times::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
  392. {
  393. arma_extra_debug_sigprint();
  394. const sword N_mat = 1 + depth_lhs< glue_times, Glue<T1,T2,glue_times> >::num;
  395. arma_extra_debug_print(arma_str::format("N_mat = %d") % N_mat);
  396. glue_times_redirect<N_mat>::apply(out, X);
  397. }
  398. template<typename T1>
  399. arma_hot
  400. inline
  401. void
  402. glue_times::apply_inplace(Mat<typename T1::elem_type>& out, const T1& X)
  403. {
  404. arma_extra_debug_sigprint();
  405. out = out * X;
  406. }
  407. template<typename T1, typename T2>
  408. arma_hot
  409. inline
  410. void
  411. glue_times::apply_inplace_plus(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_times>& X, const sword sign)
  412. {
  413. arma_extra_debug_sigprint();
  414. typedef typename T1::elem_type eT;
  415. typedef typename get_pod_type<eT>::result T;
  416. if( (is_outer_product<T1>::value) || (has_op_inv<T1>::value) || (has_op_inv<T2>::value) || (has_op_inv_sympd<T1>::value) || (has_op_inv_sympd<T2>::value) )
  417. {
  418. // partial workaround for corner cases
  419. const Mat<eT> tmp(X);
  420. if(sign > sword(0)) { out += tmp; } else { out -= tmp; }
  421. return;
  422. }
  423. const partial_unwrap_check<T1> tmp1(X.A, out);
  424. const partial_unwrap_check<T2> tmp2(X.B, out);
  425. typedef typename partial_unwrap_check<T1>::stored_type TA;
  426. typedef typename partial_unwrap_check<T2>::stored_type TB;
  427. const TA& A = tmp1.M;
  428. const TB& B = tmp2.M;
  429. const bool do_trans_A = partial_unwrap_check<T1>::do_trans;
  430. const bool do_trans_B = partial_unwrap_check<T2>::do_trans;
  431. const bool use_alpha = partial_unwrap_check<T1>::do_times || partial_unwrap_check<T2>::do_times || (sign < sword(0));
  432. const eT alpha = use_alpha ? ( tmp1.get_val() * tmp2.get_val() * ( (sign > sword(0)) ? eT(1) : eT(-1) ) ) : eT(0);
  433. arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiplication");
  434. const uword result_n_rows = (do_trans_A == false) ? (TA::is_row ? 1 : A.n_rows) : (TA::is_col ? 1 : A.n_cols);
  435. const uword result_n_cols = (do_trans_B == false) ? (TB::is_col ? 1 : B.n_cols) : (TB::is_row ? 1 : B.n_rows);
  436. arma_debug_assert_same_size(out.n_rows, out.n_cols, result_n_rows, result_n_cols, ( (sign > sword(0)) ? "addition" : "subtraction" ) );
  437. if(out.n_elem == 0)
  438. {
  439. return;
  440. }
  441. if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) )
  442. {
  443. if( ((A.n_rows == 1) || (TA::is_row)) && (is_cx<eT>::no) ) { gemv<true, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); }
  444. else if( (B.n_cols == 1) || (TB::is_col) ) { gemv<false, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); }
  445. else { gemm<false, false, false, true>::apply(out, A, B, alpha, eT(1)); }
  446. }
  447. else
  448. if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) )
  449. {
  450. if( ((A.n_rows == 1) || (TA::is_row)) && (is_cx<eT>::no) ) { gemv<true, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); }
  451. else if( (B.n_cols == 1) || (TB::is_col) ) { gemv<false, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); }
  452. else { gemm<false, false, true, true>::apply(out, A, B, alpha, eT(1)); }
  453. }
  454. else
  455. if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) )
  456. {
  457. if( ((A.n_cols == 1) || (TA::is_col)) && (is_cx<eT>::no) ) { gemv<true, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); }
  458. else if( (B.n_cols == 1) || (TB::is_col) ) { gemv<true, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); }
  459. else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx<eT>::no) ) { syrk<true, false, true>::apply(out, A, alpha, eT(1)); }
  460. else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx<eT>::yes) ) { herk<true, false, true>::apply(out, A, T(0), T(1)); }
  461. else { gemm<true, false, false, true>::apply(out, A, B, alpha, eT(1)); }
  462. }
  463. else
  464. if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) )
  465. {
  466. if( ((A.n_cols == 1) || (TA::is_col)) && (is_cx<eT>::no) ) { gemv<true, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); }
  467. else if( (B.n_cols == 1) || (TB::is_col) ) { gemv<true, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); }
  468. else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx<eT>::no) ) { syrk<true, true, true>::apply(out, A, alpha, eT(1)); }
  469. else { gemm<true, false, true, true>::apply(out, A, B, alpha, eT(1)); }
  470. }
  471. else
  472. if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) )
  473. {
  474. if( ((A.n_rows == 1) || (TA::is_row)) && (is_cx<eT>::no) ) { gemv<false, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); }
  475. else if( ((B.n_rows == 1) || (TB::is_row)) && (is_cx<eT>::no) ) { gemv<false, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); }
  476. else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx<eT>::no) ) { syrk<false, false, true>::apply(out, A, alpha, eT(1)); }
  477. else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx<eT>::yes) ) { herk<false, false, true>::apply(out, A, T(0), T(1)); }
  478. else { gemm<false, true, false, true>::apply(out, A, B, alpha, eT(1)); }
  479. }
  480. else
  481. if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) )
  482. {
  483. if( ((A.n_rows == 1) || (TA::is_row)) && (is_cx<eT>::no) ) { gemv<false, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); }
  484. else if( ((B.n_rows == 1) || (TB::is_row)) && (is_cx<eT>::no) ) { gemv<false, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); }
  485. else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx<eT>::no) ) { syrk<false, true, true>::apply(out, A, alpha, eT(1)); }
  486. else { gemm<false, true, true, true>::apply(out, A, B, alpha, eT(1)); }
  487. }
  488. else
  489. if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) )
  490. {
  491. if( ((A.n_cols == 1) || (TA::is_col)) && (is_cx<eT>::no) ) { gemv<false, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); }
  492. else if( ((B.n_rows == 1) || (TB::is_row)) && (is_cx<eT>::no) ) { gemv<true, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); }
  493. else { gemm<true, true, false, true>::apply(out, A, B, alpha, eT(1)); }
  494. }
  495. else
  496. if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) )
  497. {
  498. if( ((A.n_cols == 1) || (TA::is_col)) && (is_cx<eT>::no) ) { gemv<false, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); }
  499. else if( ((B.n_rows == 1) || (TB::is_row)) && (is_cx<eT>::no) ) { gemv<true, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); }
  500. else { gemm<true, true, true, true>::apply(out, A, B, alpha, eT(1)); }
  501. }
  502. }
  503. template<typename eT, const bool do_trans_A, const bool do_trans_B, typename TA, typename TB>
  504. arma_inline
  505. uword
  506. glue_times::mul_storage_cost(const TA& A, const TB& B)
  507. {
  508. const uword final_A_n_rows = (do_trans_A == false) ? ( TA::is_row ? 1 : A.n_rows ) : ( TA::is_col ? 1 : A.n_cols );
  509. const uword final_B_n_cols = (do_trans_B == false) ? ( TB::is_col ? 1 : B.n_cols ) : ( TB::is_row ? 1 : B.n_rows );
  510. return final_A_n_rows * final_B_n_cols;
  511. }
  512. template
  513. <
  514. typename eT,
  515. const bool do_trans_A,
  516. const bool do_trans_B,
  517. const bool use_alpha,
  518. typename TA,
  519. typename TB
  520. >
  521. arma_hot
  522. inline
  523. void
  524. glue_times::apply
  525. (
  526. Mat<eT>& out,
  527. const TA& A,
  528. const TB& B,
  529. const eT alpha
  530. )
  531. {
  532. arma_extra_debug_sigprint();
  533. //arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiplication");
  534. arma_debug_assert_trans_mul_size<do_trans_A, do_trans_B>(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication");
  535. const uword final_n_rows = (do_trans_A == false) ? (TA::is_row ? 1 : A.n_rows) : (TA::is_col ? 1 : A.n_cols);
  536. const uword final_n_cols = (do_trans_B == false) ? (TB::is_col ? 1 : B.n_cols) : (TB::is_row ? 1 : B.n_rows);
  537. out.set_size(final_n_rows, final_n_cols);
  538. if( (A.n_elem == 0) || (B.n_elem == 0) )
  539. {
  540. out.zeros();
  541. return;
  542. }
  543. if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) )
  544. {
  545. if( ((A.n_rows == 1) || (TA::is_row)) && (is_cx<eT>::no) ) { gemv<true, false, false>::apply(out.memptr(), B, A.memptr()); }
  546. else if( (B.n_cols == 1) || (TB::is_col) ) { gemv<false, false, false>::apply(out.memptr(), A, B.memptr()); }
  547. else { gemm<false, false, false, false>::apply(out, A, B ); }
  548. }
  549. else
  550. if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) )
  551. {
  552. if( ((A.n_rows == 1) || (TA::is_row)) && (is_cx<eT>::no) ) { gemv<true, true, false>::apply(out.memptr(), B, A.memptr(), alpha); }
  553. else if( (B.n_cols == 1) || (TB::is_col) ) { gemv<false, true, false>::apply(out.memptr(), A, B.memptr(), alpha); }
  554. else { gemm<false, false, true, false>::apply(out, A, B, alpha); }
  555. }
  556. else
  557. if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) )
  558. {
  559. if( ((A.n_cols == 1) || (TA::is_col)) && (is_cx<eT>::no) ) { gemv<true, false, false>::apply(out.memptr(), B, A.memptr()); }
  560. else if( (B.n_cols == 1) || (TB::is_col) ) { gemv<true, false, false>::apply(out.memptr(), A, B.memptr()); }
  561. else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx<eT>::no) ) { syrk<true, false, false>::apply(out, A ); }
  562. else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx<eT>::yes) ) { herk<true, false, false>::apply(out, A ); }
  563. else { gemm<true, false, false, false>::apply(out, A, B ); }
  564. }
  565. else
  566. if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) )
  567. {
  568. if( ((A.n_cols == 1) || (TA::is_col)) && (is_cx<eT>::no) ) { gemv<true, true, false>::apply(out.memptr(), B, A.memptr(), alpha); }
  569. else if( (B.n_cols == 1) || (TB::is_col) ) { gemv<true, true, false>::apply(out.memptr(), A, B.memptr(), alpha); }
  570. else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx<eT>::no) ) { syrk<true, true, false>::apply(out, A, alpha); }
  571. else { gemm<true, false, true, false>::apply(out, A, B, alpha); }
  572. }
  573. else
  574. if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) )
  575. {
  576. if( ((A.n_rows == 1) || (TA::is_row)) && (is_cx<eT>::no) ) { gemv<false, false, false>::apply(out.memptr(), B, A.memptr()); }
  577. else if( ((B.n_rows == 1) || (TB::is_row)) && (is_cx<eT>::no) ) { gemv<false, false, false>::apply(out.memptr(), A, B.memptr()); }
  578. else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx<eT>::no) ) { syrk<false, false, false>::apply(out, A ); }
  579. else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx<eT>::yes) ) { herk<false, false, false>::apply(out, A ); }
  580. else { gemm<false, true, false, false>::apply(out, A, B ); }
  581. }
  582. else
  583. if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) )
  584. {
  585. if( ((A.n_rows == 1) || (TA::is_row)) && (is_cx<eT>::no) ) { gemv<false, true, false>::apply(out.memptr(), B, A.memptr(), alpha); }
  586. else if( ((B.n_rows == 1) || (TB::is_row)) && (is_cx<eT>::no) ) { gemv<false, true, false>::apply(out.memptr(), A, B.memptr(), alpha); }
  587. else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx<eT>::no) ) { syrk<false, true, false>::apply(out, A, alpha); }
  588. else { gemm<false, true, true, false>::apply(out, A, B, alpha); }
  589. }
  590. else
  591. if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) )
  592. {
  593. if( ((A.n_cols == 1) || (TA::is_col)) && (is_cx<eT>::no) ) { gemv<false, false, false>::apply(out.memptr(), B, A.memptr()); }
  594. else if( ((B.n_rows == 1) || (TB::is_row)) && (is_cx<eT>::no) ) { gemv<true, false, false>::apply(out.memptr(), A, B.memptr()); }
  595. else { gemm<true, true, false, false>::apply(out, A, B ); }
  596. }
  597. else
  598. if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) )
  599. {
  600. if( ((A.n_cols == 1) || (TA::is_col)) && (is_cx<eT>::no) ) { gemv<false, true, false>::apply(out.memptr(), B, A.memptr(), alpha); }
  601. else if( ((B.n_rows == 1) || (TB::is_row)) && (is_cx<eT>::no) ) { gemv<true, true, false>::apply(out.memptr(), A, B.memptr(), alpha); }
  602. else { gemm<true, true, true, false>::apply(out, A, B, alpha); }
  603. }
  604. }
  605. template
  606. <
  607. typename eT,
  608. const bool do_trans_A,
  609. const bool do_trans_B,
  610. const bool do_trans_C,
  611. const bool use_alpha,
  612. typename TA,
  613. typename TB,
  614. typename TC
  615. >
  616. arma_hot
  617. inline
  618. void
  619. glue_times::apply
  620. (
  621. Mat<eT>& out,
  622. const TA& A,
  623. const TB& B,
  624. const TC& C,
  625. const eT alpha
  626. )
  627. {
  628. arma_extra_debug_sigprint();
  629. Mat<eT> tmp;
  630. const uword storage_cost_AB = glue_times::mul_storage_cost<eT, do_trans_A, do_trans_B>(A, B);
  631. const uword storage_cost_BC = glue_times::mul_storage_cost<eT, do_trans_B, do_trans_C>(B, C);
  632. if(storage_cost_AB <= storage_cost_BC)
  633. {
  634. // out = (A*B)*C
  635. glue_times::apply<eT, do_trans_A, do_trans_B, use_alpha>(tmp, A, B, alpha);
  636. glue_times::apply<eT, false, do_trans_C, false >(out, tmp, C, eT(0));
  637. }
  638. else
  639. {
  640. // out = A*(B*C)
  641. glue_times::apply<eT, do_trans_B, do_trans_C, use_alpha>(tmp, B, C, alpha);
  642. glue_times::apply<eT, do_trans_A, false, false >(out, A, tmp, eT(0));
  643. }
  644. }
  645. template
  646. <
  647. typename eT,
  648. const bool do_trans_A,
  649. const bool do_trans_B,
  650. const bool do_trans_C,
  651. const bool do_trans_D,
  652. const bool use_alpha,
  653. typename TA,
  654. typename TB,
  655. typename TC,
  656. typename TD
  657. >
  658. arma_hot
  659. inline
  660. void
  661. glue_times::apply
  662. (
  663. Mat<eT>& out,
  664. const TA& A,
  665. const TB& B,
  666. const TC& C,
  667. const TD& D,
  668. const eT alpha
  669. )
  670. {
  671. arma_extra_debug_sigprint();
  672. Mat<eT> tmp;
  673. const uword storage_cost_AC = glue_times::mul_storage_cost<eT, do_trans_A, do_trans_C>(A, C);
  674. const uword storage_cost_BD = glue_times::mul_storage_cost<eT, do_trans_B, do_trans_D>(B, D);
  675. if(storage_cost_AC <= storage_cost_BD)
  676. {
  677. // out = (A*B*C)*D
  678. glue_times::apply<eT, do_trans_A, do_trans_B, do_trans_C, use_alpha>(tmp, A, B, C, alpha);
  679. glue_times::apply<eT, false, do_trans_D, false>(out, tmp, D, eT(0));
  680. }
  681. else
  682. {
  683. // out = A*(B*C*D)
  684. glue_times::apply<eT, do_trans_B, do_trans_C, do_trans_D, use_alpha>(tmp, B, C, D, alpha);
  685. glue_times::apply<eT, do_trans_A, false, false>(out, A, tmp, eT(0));
  686. }
  687. }
  688. //
  689. // glue_times_diag
  690. template<typename T1, typename T2>
  691. arma_hot
  692. inline
  693. void
  694. glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_times_diag>& X)
  695. {
  696. arma_extra_debug_sigprint();
  697. typedef typename T1::elem_type eT;
  698. const strip_diagmat<T1> S1(X.A);
  699. const strip_diagmat<T2> S2(X.B);
  700. typedef typename strip_diagmat<T1>::stored_type T1_stripped;
  701. typedef typename strip_diagmat<T2>::stored_type T2_stripped;
  702. if( (strip_diagmat<T1>::do_diagmat == true) && (strip_diagmat<T2>::do_diagmat == false) )
  703. {
  704. arma_extra_debug_print("glue_times_diag::apply(): diagmat(A) * B");
  705. const diagmat_proxy_check<T1_stripped> A(S1.M, out);
  706. const unwrap_check<T2> tmp(X.B, out);
  707. const Mat<eT>& B = tmp.M;
  708. const uword A_n_rows = A.n_rows;
  709. const uword A_n_cols = A.n_cols;
  710. const uword A_length = (std::min)(A_n_rows, A_n_cols);
  711. const uword B_n_rows = B.n_rows;
  712. const uword B_n_cols = B.n_cols;
  713. arma_debug_assert_mul_size(A_n_rows, A_n_cols, B_n_rows, B_n_cols, "matrix multiplication");
  714. out.zeros(A_n_rows, B_n_cols);
  715. for(uword col=0; col < B_n_cols; ++col)
  716. {
  717. eT* out_coldata = out.colptr(col);
  718. const eT* B_coldata = B.colptr(col);
  719. for(uword i=0; i < A_length; ++i)
  720. {
  721. out_coldata[i] = A[i] * B_coldata[i];
  722. }
  723. }
  724. }
  725. else
  726. if( (strip_diagmat<T1>::do_diagmat == false) && (strip_diagmat<T2>::do_diagmat == true) )
  727. {
  728. arma_extra_debug_print("glue_times_diag::apply(): A * diagmat(B)");
  729. const unwrap_check<T1> tmp(X.A, out);
  730. const Mat<eT>& A = tmp.M;
  731. const diagmat_proxy_check<T2_stripped> B(S2.M, out);
  732. const uword A_n_rows = A.n_rows;
  733. const uword A_n_cols = A.n_cols;
  734. const uword B_n_rows = B.n_rows;
  735. const uword B_n_cols = B.n_cols;
  736. const uword B_length = (std::min)(B_n_rows, B_n_cols);
  737. arma_debug_assert_mul_size(A_n_rows, A_n_cols, B_n_rows, B_n_cols, "matrix multiplication");
  738. out.zeros(A_n_rows, B_n_cols);
  739. for(uword col=0; col < B_length; ++col)
  740. {
  741. const eT val = B[col];
  742. eT* out_coldata = out.colptr(col);
  743. const eT* A_coldata = A.colptr(col);
  744. for(uword i=0; i < A_n_rows; ++i)
  745. {
  746. out_coldata[i] = A_coldata[i] * val;
  747. }
  748. }
  749. }
  750. else
  751. if( (strip_diagmat<T1>::do_diagmat == true) && (strip_diagmat<T2>::do_diagmat == true) )
  752. {
  753. arma_extra_debug_print("glue_times_diag::apply(): diagmat(A) * diagmat(B)");
  754. const diagmat_proxy_check<T1_stripped> A(S1.M, out);
  755. const diagmat_proxy_check<T2_stripped> B(S2.M, out);
  756. arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication");
  757. out.zeros(A.n_rows, B.n_cols);
  758. const uword A_length = (std::min)(A.n_rows, A.n_cols);
  759. const uword B_length = (std::min)(B.n_rows, B.n_cols);
  760. const uword N = (std::min)(A_length, B_length);
  761. for(uword i=0; i < N; ++i)
  762. {
  763. out.at(i,i) = A[i] * B[i];
  764. }
  765. }
  766. }
  767. //! @}