glue_mixed_meat.hpp 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup glue_mixed
  16. //! @{
  17. //! matrix multiplication with different element types
  18. template<typename T1, typename T2>
  19. inline
  20. void
  21. glue_mixed_times::apply(Mat<typename eT_promoter<T1,T2>::eT>& out, const mtGlue<typename eT_promoter<T1,T2>::eT, T1, T2, glue_mixed_times>& X)
  22. {
  23. arma_extra_debug_sigprint();
  24. typedef typename T1::elem_type in_eT1;
  25. typedef typename T2::elem_type in_eT2;
  26. typedef typename eT_promoter<T1,T2>::eT out_eT;
  27. const partial_unwrap<T1> tmp1(X.A);
  28. const partial_unwrap<T2> tmp2(X.B);
  29. const typename partial_unwrap<T1>::stored_type& A = tmp1.M;
  30. const typename partial_unwrap<T2>::stored_type& B = tmp2.M;
  31. const bool use_alpha = partial_unwrap<T1>::do_times || partial_unwrap<T2>::do_times;
  32. const out_eT alpha = use_alpha ? (upgrade_val<in_eT1,in_eT2>::apply(tmp1.get_val()) * upgrade_val<in_eT1,in_eT2>::apply(tmp2.get_val())) : out_eT(0);
  33. const bool do_trans_A = partial_unwrap<T1>::do_trans;
  34. const bool do_trans_B = partial_unwrap<T2>::do_trans;
  35. arma_debug_assert_trans_mul_size<do_trans_A, do_trans_B>(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication");
  36. const uword out_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols;
  37. const uword out_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows;
  38. const bool alias = tmp1.is_alias(out) || tmp2.is_alias(out);
  39. if(alias == false)
  40. {
  41. out.set_size(out_n_rows, out_n_cols);
  42. gemm_mixed<do_trans_A, do_trans_B, use_alpha, false>::apply(out, A, B, alpha);
  43. }
  44. else
  45. {
  46. Mat<out_eT> tmp(out_n_rows, out_n_cols);
  47. gemm_mixed<do_trans_A, do_trans_B, use_alpha, false>::apply(tmp, A, B, alpha);
  48. out.steal_mem(tmp);
  49. }
  50. }
  51. //! matrix addition with different element types
  52. template<typename T1, typename T2>
  53. inline
  54. void
  55. glue_mixed_plus::apply(Mat<typename eT_promoter<T1,T2>::eT>& out, const mtGlue<typename eT_promoter<T1,T2>::eT, T1, T2, glue_mixed_plus>& X)
  56. {
  57. arma_extra_debug_sigprint();
  58. typedef typename T1::elem_type eT1;
  59. typedef typename T2::elem_type eT2;
  60. typedef typename promote_type<eT1,eT2>::result out_eT;
  61. promote_type<eT1,eT2>::check();
  62. const Proxy<T1> A(X.A);
  63. const Proxy<T2> B(X.B);
  64. arma_debug_assert_same_size(A, B, "addition");
  65. const uword n_rows = A.get_n_rows();
  66. const uword n_cols = A.get_n_cols();
  67. out.set_size(n_rows, n_cols);
  68. out_eT* out_mem = out.memptr();
  69. const uword n_elem = out.n_elem;
  70. const bool use_at = (Proxy<T1>::use_at || Proxy<T2>::use_at);
  71. if(use_at == false)
  72. {
  73. typename Proxy<T1>::ea_type AA = A.get_ea();
  74. typename Proxy<T2>::ea_type BB = B.get_ea();
  75. if(memory::is_aligned(out_mem))
  76. {
  77. memory::mark_as_aligned(out_mem);
  78. for(uword i=0; i<n_elem; ++i)
  79. {
  80. out_mem[i] = upgrade_val<eT1,eT2>::apply(AA[i]) + upgrade_val<eT1,eT2>::apply(BB[i]);
  81. }
  82. }
  83. else
  84. {
  85. for(uword i=0; i<n_elem; ++i)
  86. {
  87. out_mem[i] = upgrade_val<eT1,eT2>::apply(AA[i]) + upgrade_val<eT1,eT2>::apply(BB[i]);
  88. }
  89. }
  90. }
  91. else
  92. {
  93. for(uword col=0; col < n_cols; ++col)
  94. for(uword row=0; row < n_rows; ++row)
  95. {
  96. (*out_mem) = upgrade_val<eT1,eT2>::apply(A.at(row,col)) + upgrade_val<eT1,eT2>::apply(B.at(row,col));
  97. out_mem++;
  98. }
  99. }
  100. }
  101. //! matrix subtraction with different element types
  102. template<typename T1, typename T2>
  103. inline
  104. void
  105. glue_mixed_minus::apply(Mat<typename eT_promoter<T1,T2>::eT>& out, const mtGlue<typename eT_promoter<T1,T2>::eT, T1, T2, glue_mixed_minus>& X)
  106. {
  107. arma_extra_debug_sigprint();
  108. typedef typename T1::elem_type eT1;
  109. typedef typename T2::elem_type eT2;
  110. typedef typename promote_type<eT1,eT2>::result out_eT;
  111. promote_type<eT1,eT2>::check();
  112. const Proxy<T1> A(X.A);
  113. const Proxy<T2> B(X.B);
  114. arma_debug_assert_same_size(A, B, "subtraction");
  115. const uword n_rows = A.get_n_rows();
  116. const uword n_cols = A.get_n_cols();
  117. out.set_size(n_rows, n_cols);
  118. out_eT* out_mem = out.memptr();
  119. const uword n_elem = out.n_elem;
  120. const bool use_at = (Proxy<T1>::use_at || Proxy<T2>::use_at);
  121. if(use_at == false)
  122. {
  123. typename Proxy<T1>::ea_type AA = A.get_ea();
  124. typename Proxy<T2>::ea_type BB = B.get_ea();
  125. if(memory::is_aligned(out_mem))
  126. {
  127. memory::mark_as_aligned(out_mem);
  128. for(uword i=0; i<n_elem; ++i)
  129. {
  130. out_mem[i] = upgrade_val<eT1,eT2>::apply(AA[i]) - upgrade_val<eT1,eT2>::apply(BB[i]);
  131. }
  132. }
  133. else
  134. {
  135. for(uword i=0; i<n_elem; ++i)
  136. {
  137. out_mem[i] = upgrade_val<eT1,eT2>::apply(AA[i]) - upgrade_val<eT1,eT2>::apply(BB[i]);
  138. }
  139. }
  140. }
  141. else
  142. {
  143. for(uword col=0; col < n_cols; ++col)
  144. for(uword row=0; row < n_rows; ++row)
  145. {
  146. (*out_mem) = upgrade_val<eT1,eT2>::apply(A.at(row,col)) - upgrade_val<eT1,eT2>::apply(B.at(row,col));
  147. out_mem++;
  148. }
  149. }
  150. }
  151. //! element-wise matrix division with different element types
  152. template<typename T1, typename T2>
  153. inline
  154. void
  155. glue_mixed_div::apply(Mat<typename eT_promoter<T1,T2>::eT>& out, const mtGlue<typename eT_promoter<T1,T2>::eT, T1, T2, glue_mixed_div>& X)
  156. {
  157. arma_extra_debug_sigprint();
  158. typedef typename T1::elem_type eT1;
  159. typedef typename T2::elem_type eT2;
  160. typedef typename promote_type<eT1,eT2>::result out_eT;
  161. promote_type<eT1,eT2>::check();
  162. const Proxy<T1> A(X.A);
  163. const Proxy<T2> B(X.B);
  164. arma_debug_assert_same_size(A, B, "element-wise division");
  165. const uword n_rows = A.get_n_rows();
  166. const uword n_cols = A.get_n_cols();
  167. out.set_size(n_rows, n_cols);
  168. out_eT* out_mem = out.memptr();
  169. const uword n_elem = out.n_elem;
  170. const bool use_at = (Proxy<T1>::use_at || Proxy<T2>::use_at);
  171. if(use_at == false)
  172. {
  173. typename Proxy<T1>::ea_type AA = A.get_ea();
  174. typename Proxy<T2>::ea_type BB = B.get_ea();
  175. if(memory::is_aligned(out_mem))
  176. {
  177. memory::mark_as_aligned(out_mem);
  178. for(uword i=0; i<n_elem; ++i)
  179. {
  180. out_mem[i] = upgrade_val<eT1,eT2>::apply(AA[i]) / upgrade_val<eT1,eT2>::apply(BB[i]);
  181. }
  182. }
  183. else
  184. {
  185. for(uword i=0; i<n_elem; ++i)
  186. {
  187. out_mem[i] = upgrade_val<eT1,eT2>::apply(AA[i]) / upgrade_val<eT1,eT2>::apply(BB[i]);
  188. }
  189. }
  190. }
  191. else
  192. {
  193. for(uword col=0; col < n_cols; ++col)
  194. for(uword row=0; row < n_rows; ++row)
  195. {
  196. (*out_mem) = upgrade_val<eT1,eT2>::apply(A.at(row,col)) / upgrade_val<eT1,eT2>::apply(B.at(row,col));
  197. out_mem++;
  198. }
  199. }
  200. }
  201. //! element-wise matrix multiplication with different element types
  202. template<typename T1, typename T2>
  203. inline
  204. void
  205. glue_mixed_schur::apply(Mat<typename eT_promoter<T1,T2>::eT>& out, const mtGlue<typename eT_promoter<T1,T2>::eT, T1, T2, glue_mixed_schur>& X)
  206. {
  207. arma_extra_debug_sigprint();
  208. typedef typename T1::elem_type eT1;
  209. typedef typename T2::elem_type eT2;
  210. typedef typename promote_type<eT1,eT2>::result out_eT;
  211. promote_type<eT1,eT2>::check();
  212. const Proxy<T1> A(X.A);
  213. const Proxy<T2> B(X.B);
  214. arma_debug_assert_same_size(A, B, "element-wise multiplication");
  215. const uword n_rows = A.get_n_rows();
  216. const uword n_cols = A.get_n_cols();
  217. out.set_size(n_rows, n_cols);
  218. out_eT* out_mem = out.memptr();
  219. const uword n_elem = out.n_elem;
  220. const bool use_at = (Proxy<T1>::use_at || Proxy<T2>::use_at);
  221. if(use_at == false)
  222. {
  223. typename Proxy<T1>::ea_type AA = A.get_ea();
  224. typename Proxy<T2>::ea_type BB = B.get_ea();
  225. if(memory::is_aligned(out_mem))
  226. {
  227. memory::mark_as_aligned(out_mem);
  228. for(uword i=0; i<n_elem; ++i)
  229. {
  230. out_mem[i] = upgrade_val<eT1,eT2>::apply(AA[i]) * upgrade_val<eT1,eT2>::apply(BB[i]);
  231. }
  232. }
  233. else
  234. {
  235. for(uword i=0; i<n_elem; ++i)
  236. {
  237. out_mem[i] = upgrade_val<eT1,eT2>::apply(AA[i]) * upgrade_val<eT1,eT2>::apply(BB[i]);
  238. }
  239. }
  240. }
  241. else
  242. {
  243. for(uword col=0; col < n_cols; ++col)
  244. for(uword row=0; row < n_rows; ++row)
  245. {
  246. (*out_mem) = upgrade_val<eT1,eT2>::apply(A.at(row,col)) * upgrade_val<eT1,eT2>::apply(B.at(row,col));
  247. out_mem++;
  248. }
  249. }
  250. }
  251. //
  252. //
  253. //
  254. //! cube addition with different element types
  255. template<typename T1, typename T2>
  256. inline
  257. void
  258. glue_mixed_plus::apply(Cube<typename eT_promoter<T1,T2>::eT>& out, const mtGlueCube<typename eT_promoter<T1,T2>::eT, T1, T2, glue_mixed_plus>& X)
  259. {
  260. arma_extra_debug_sigprint();
  261. typedef typename T1::elem_type eT1;
  262. typedef typename T2::elem_type eT2;
  263. typedef typename promote_type<eT1,eT2>::result out_eT;
  264. promote_type<eT1,eT2>::check();
  265. const ProxyCube<T1> A(X.A);
  266. const ProxyCube<T2> B(X.B);
  267. arma_debug_assert_same_size(A, B, "addition");
  268. const uword n_rows = A.get_n_rows();
  269. const uword n_cols = A.get_n_cols();
  270. const uword n_slices = A.get_n_slices();
  271. out.set_size(n_rows, n_cols, n_slices);
  272. out_eT* out_mem = out.memptr();
  273. const uword n_elem = out.n_elem;
  274. const bool use_at = (ProxyCube<T1>::use_at || ProxyCube<T2>::use_at);
  275. if(use_at == false)
  276. {
  277. typename ProxyCube<T1>::ea_type AA = A.get_ea();
  278. typename ProxyCube<T2>::ea_type BB = B.get_ea();
  279. for(uword i=0; i<n_elem; ++i)
  280. {
  281. out_mem[i] = upgrade_val<eT1,eT2>::apply(AA[i]) + upgrade_val<eT1,eT2>::apply(BB[i]);
  282. }
  283. }
  284. else
  285. {
  286. for(uword slice = 0; slice < n_slices; ++slice)
  287. for(uword col = 0; col < n_cols; ++col )
  288. for(uword row = 0; row < n_rows; ++row )
  289. {
  290. (*out_mem) = upgrade_val<eT1,eT2>::apply(A.at(row,col,slice)) + upgrade_val<eT1,eT2>::apply(B.at(row,col,slice));
  291. out_mem++;
  292. }
  293. }
  294. }
  295. //! cube subtraction with different element types
  296. template<typename T1, typename T2>
  297. inline
  298. void
  299. glue_mixed_minus::apply(Cube<typename eT_promoter<T1,T2>::eT>& out, const mtGlueCube<typename eT_promoter<T1,T2>::eT, T1, T2, glue_mixed_minus>& X)
  300. {
  301. arma_extra_debug_sigprint();
  302. typedef typename T1::elem_type eT1;
  303. typedef typename T2::elem_type eT2;
  304. typedef typename promote_type<eT1,eT2>::result out_eT;
  305. promote_type<eT1,eT2>::check();
  306. const ProxyCube<T1> A(X.A);
  307. const ProxyCube<T2> B(X.B);
  308. arma_debug_assert_same_size(A, B, "subtraction");
  309. const uword n_rows = A.get_n_rows();
  310. const uword n_cols = A.get_n_cols();
  311. const uword n_slices = A.get_n_slices();
  312. out.set_size(n_rows, n_cols, n_slices);
  313. out_eT* out_mem = out.memptr();
  314. const uword n_elem = out.n_elem;
  315. const bool use_at = (ProxyCube<T1>::use_at || ProxyCube<T2>::use_at);
  316. if(use_at == false)
  317. {
  318. typename ProxyCube<T1>::ea_type AA = A.get_ea();
  319. typename ProxyCube<T2>::ea_type BB = B.get_ea();
  320. for(uword i=0; i<n_elem; ++i)
  321. {
  322. out_mem[i] = upgrade_val<eT1,eT2>::apply(AA[i]) - upgrade_val<eT1,eT2>::apply(BB[i]);
  323. }
  324. }
  325. else
  326. {
  327. for(uword slice = 0; slice < n_slices; ++slice)
  328. for(uword col = 0; col < n_cols; ++col )
  329. for(uword row = 0; row < n_rows; ++row )
  330. {
  331. (*out_mem) = upgrade_val<eT1,eT2>::apply(A.at(row,col,slice)) - upgrade_val<eT1,eT2>::apply(B.at(row,col,slice));
  332. out_mem++;
  333. }
  334. }
  335. }
  336. //! element-wise cube division with different element types
  337. template<typename T1, typename T2>
  338. inline
  339. void
  340. glue_mixed_div::apply(Cube<typename eT_promoter<T1,T2>::eT>& out, const mtGlueCube<typename eT_promoter<T1,T2>::eT, T1, T2, glue_mixed_div>& X)
  341. {
  342. arma_extra_debug_sigprint();
  343. typedef typename T1::elem_type eT1;
  344. typedef typename T2::elem_type eT2;
  345. typedef typename promote_type<eT1,eT2>::result out_eT;
  346. promote_type<eT1,eT2>::check();
  347. const ProxyCube<T1> A(X.A);
  348. const ProxyCube<T2> B(X.B);
  349. arma_debug_assert_same_size(A, B, "element-wise division");
  350. const uword n_rows = A.get_n_rows();
  351. const uword n_cols = A.get_n_cols();
  352. const uword n_slices = A.get_n_slices();
  353. out.set_size(n_rows, n_cols, n_slices);
  354. out_eT* out_mem = out.memptr();
  355. const uword n_elem = out.n_elem;
  356. const bool use_at = (ProxyCube<T1>::use_at || ProxyCube<T2>::use_at);
  357. if(use_at == false)
  358. {
  359. typename ProxyCube<T1>::ea_type AA = A.get_ea();
  360. typename ProxyCube<T2>::ea_type BB = B.get_ea();
  361. for(uword i=0; i<n_elem; ++i)
  362. {
  363. out_mem[i] = upgrade_val<eT1,eT2>::apply(AA[i]) / upgrade_val<eT1,eT2>::apply(BB[i]);
  364. }
  365. }
  366. else
  367. {
  368. for(uword slice = 0; slice < n_slices; ++slice)
  369. for(uword col = 0; col < n_cols; ++col )
  370. for(uword row = 0; row < n_rows; ++row )
  371. {
  372. (*out_mem) = upgrade_val<eT1,eT2>::apply(A.at(row,col,slice)) / upgrade_val<eT1,eT2>::apply(B.at(row,col,slice));
  373. out_mem++;
  374. }
  375. }
  376. }
  377. //! element-wise cube multiplication with different element types
  378. template<typename T1, typename T2>
  379. inline
  380. void
  381. glue_mixed_schur::apply(Cube<typename eT_promoter<T1,T2>::eT>& out, const mtGlueCube<typename eT_promoter<T1,T2>::eT, T1, T2, glue_mixed_schur>& X)
  382. {
  383. arma_extra_debug_sigprint();
  384. typedef typename T1::elem_type eT1;
  385. typedef typename T2::elem_type eT2;
  386. typedef typename promote_type<eT1,eT2>::result out_eT;
  387. promote_type<eT1,eT2>::check();
  388. const ProxyCube<T1> A(X.A);
  389. const ProxyCube<T2> B(X.B);
  390. arma_debug_assert_same_size(A, B, "element-wise multiplication");
  391. const uword n_rows = A.get_n_rows();
  392. const uword n_cols = A.get_n_cols();
  393. const uword n_slices = A.get_n_slices();
  394. out.set_size(n_rows, n_cols, n_slices);
  395. out_eT* out_mem = out.memptr();
  396. const uword n_elem = out.n_elem;
  397. const bool use_at = (ProxyCube<T1>::use_at || ProxyCube<T2>::use_at);
  398. if(use_at == false)
  399. {
  400. typename ProxyCube<T1>::ea_type AA = A.get_ea();
  401. typename ProxyCube<T2>::ea_type BB = B.get_ea();
  402. for(uword i=0; i<n_elem; ++i)
  403. {
  404. out_mem[i] = upgrade_val<eT1,eT2>::apply(AA[i]) * upgrade_val<eT1,eT2>::apply(BB[i]);
  405. }
  406. }
  407. else
  408. {
  409. for(uword slice = 0; slice < n_slices; ++slice)
  410. for(uword col = 0; col < n_cols; ++col )
  411. for(uword row = 0; row < n_rows; ++row )
  412. {
  413. (*out_mem) = upgrade_val<eT1,eT2>::apply(A.at(row,col,slice)) * upgrade_val<eT1,eT2>::apply(B.at(row,col,slice));
  414. out_mem++;
  415. }
  416. }
  417. }
  418. //! @}