operator_times.hpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup operator_times
  16. //! @{
  17. //! Base * scalar
  18. template<typename T1>
  19. arma_inline
  20. typename enable_if2< is_arma_type<T1>::value, const eOp<T1, eop_scalar_times> >::result
  21. operator*
  22. (const T1& X, const typename T1::elem_type k)
  23. {
  24. arma_extra_debug_sigprint();
  25. return eOp<T1, eop_scalar_times>(X,k);
  26. }
  27. //! scalar * Base
  28. template<typename T1>
  29. arma_inline
  30. typename enable_if2< is_arma_type<T1>::value, const eOp<T1, eop_scalar_times> >::result
  31. operator*
  32. (const typename T1::elem_type k, const T1& X)
  33. {
  34. arma_extra_debug_sigprint();
  35. return eOp<T1, eop_scalar_times>(X,k); // NOTE: order is swapped
  36. }
  37. //! non-complex Base * complex scalar
  38. template<typename T1>
  39. arma_inline
  40. typename
  41. enable_if2
  42. <
  43. (is_arma_type<T1>::value && is_cx<typename T1::elem_type>::no),
  44. const mtOp<typename std::complex<typename T1::pod_type>, T1, op_cx_scalar_times>
  45. >::result
  46. operator*
  47. (
  48. const T1& X,
  49. const std::complex<typename T1::pod_type>& k
  50. )
  51. {
  52. arma_extra_debug_sigprint();
  53. return mtOp<typename std::complex<typename T1::pod_type>, T1, op_cx_scalar_times>('j', X, k);
  54. }
  55. //! complex scalar * non-complex Base
  56. template<typename T1>
  57. arma_inline
  58. typename
  59. enable_if2
  60. <
  61. (is_arma_type<T1>::value && is_cx<typename T1::elem_type>::no),
  62. const mtOp<typename std::complex<typename T1::pod_type>, T1, op_cx_scalar_times>
  63. >::result
  64. operator*
  65. (
  66. const std::complex<typename T1::pod_type>& k,
  67. const T1& X
  68. )
  69. {
  70. arma_extra_debug_sigprint();
  71. return mtOp<typename std::complex<typename T1::pod_type>, T1, op_cx_scalar_times>('j', X, k);
  72. }
  73. //! scalar * trans(T1)
  74. template<typename T1>
  75. arma_inline
  76. const Op<T1, op_htrans2>
  77. operator*
  78. (const typename T1::elem_type k, const Op<T1, op_htrans>& X)
  79. {
  80. arma_extra_debug_sigprint();
  81. return Op<T1, op_htrans2>(X.m, k);
  82. }
  83. //! trans(T1) * scalar
  84. template<typename T1>
  85. arma_inline
  86. const Op<T1, op_htrans2>
  87. operator*
  88. (const Op<T1, op_htrans>& X, const typename T1::elem_type k)
  89. {
  90. arma_extra_debug_sigprint();
  91. return Op<T1, op_htrans2>(X.m, k);
  92. }
  93. //! Base * diagmat
  94. template<typename T1, typename T2>
  95. arma_inline
  96. typename
  97. enable_if2
  98. <
  99. (is_arma_type<T1>::value && is_same_type<typename T1::elem_type, typename T2::elem_type>::value),
  100. const Glue<T1, Op<T2, op_diagmat>, glue_times_diag>
  101. >::result
  102. operator*
  103. (const T1& X, const Op<T2, op_diagmat>& Y)
  104. {
  105. arma_extra_debug_sigprint();
  106. return Glue<T1, Op<T2, op_diagmat>, glue_times_diag>(X, Y);
  107. }
  108. //! diagmat * Base
  109. template<typename T1, typename T2>
  110. arma_inline
  111. typename
  112. enable_if2
  113. <
  114. (is_arma_type<T2>::value && is_same_type<typename T1::elem_type, typename T2::elem_type>::value),
  115. const Glue<Op<T1, op_diagmat>, T2, glue_times_diag>
  116. >::result
  117. operator*
  118. (const Op<T1, op_diagmat>& X, const T2& Y)
  119. {
  120. arma_extra_debug_sigprint();
  121. return Glue<Op<T1, op_diagmat>, T2, glue_times_diag>(X, Y);
  122. }
  123. //! diagmat * diagmat
  124. template<typename T1, typename T2>
  125. inline
  126. Mat< typename promote_type<typename T1::elem_type, typename T2::elem_type>::result >
  127. operator*
  128. (const Op<T1, op_diagmat>& X, const Op<T2, op_diagmat>& Y)
  129. {
  130. arma_extra_debug_sigprint();
  131. typedef typename T1::elem_type eT1;
  132. typedef typename T2::elem_type eT2;
  133. typedef typename promote_type<eT1,eT2>::result out_eT;
  134. promote_type<eT1,eT2>::check();
  135. const diagmat_proxy<T1> A(X.m);
  136. const diagmat_proxy<T2> B(Y.m);
  137. arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication");
  138. Mat<out_eT> out(A.n_rows, B.n_cols, fill::zeros);
  139. const uword A_length = (std::min)(A.n_rows, A.n_cols);
  140. const uword B_length = (std::min)(B.n_rows, B.n_cols);
  141. const uword N = (std::min)(A_length, B_length);
  142. for(uword i=0; i<N; ++i)
  143. {
  144. out.at(i,i) = upgrade_val<eT1,eT2>::apply( A[i] ) * upgrade_val<eT1,eT2>::apply( B[i] );
  145. }
  146. return out;
  147. }
  148. //! multiplication of Base objects with same element type
  149. template<typename T1, typename T2>
  150. arma_inline
  151. typename
  152. enable_if2
  153. <
  154. is_arma_type<T1>::value && is_arma_type<T2>::value && is_same_type<typename T1::elem_type, typename T2::elem_type>::value,
  155. const Glue<T1, T2, glue_times>
  156. >::result
  157. operator*
  158. (const T1& X, const T2& Y)
  159. {
  160. arma_extra_debug_sigprint();
  161. return Glue<T1, T2, glue_times>(X, Y);
  162. }
  163. //! multiplication of Base objects with different element types
  164. template<typename T1, typename T2>
  165. inline
  166. typename
  167. enable_if2
  168. <
  169. (is_arma_type<T1>::value && is_arma_type<T2>::value && (is_same_type<typename T1::elem_type, typename T2::elem_type>::no)),
  170. const mtGlue< typename promote_type<typename T1::elem_type, typename T2::elem_type>::result, T1, T2, glue_mixed_times >
  171. >::result
  172. operator*
  173. (
  174. const T1& X,
  175. const T2& Y
  176. )
  177. {
  178. arma_extra_debug_sigprint();
  179. typedef typename T1::elem_type eT1;
  180. typedef typename T2::elem_type eT2;
  181. typedef typename promote_type<eT1,eT2>::result out_eT;
  182. promote_type<eT1,eT2>::check();
  183. return mtGlue<out_eT, T1, T2, glue_mixed_times>( X, Y );
  184. }
  185. //! sparse multiplied by scalar
  186. template<typename T1>
  187. inline
  188. typename
  189. enable_if2
  190. <
  191. is_arma_sparse_type<T1>::value,
  192. SpOp<T1,spop_scalar_times>
  193. >::result
  194. operator*
  195. (
  196. const T1& X,
  197. const typename T1::elem_type k
  198. )
  199. {
  200. arma_extra_debug_sigprint();
  201. return SpOp<T1,spop_scalar_times>(X, k);
  202. }
  203. template<typename T1>
  204. inline
  205. typename
  206. enable_if2
  207. <
  208. is_arma_sparse_type<T1>::value,
  209. SpOp<T1,spop_scalar_times>
  210. >::result
  211. operator*
  212. (
  213. const typename T1::elem_type k,
  214. const T1& X
  215. )
  216. {
  217. arma_extra_debug_sigprint();
  218. return SpOp<T1,spop_scalar_times>(X, k);
  219. }
  220. //! non-complex sparse * complex scalar
  221. template<typename T1>
  222. arma_inline
  223. typename
  224. enable_if2
  225. <
  226. (is_arma_sparse_type<T1>::value && is_cx<typename T1::elem_type>::no),
  227. const mtSpOp<typename std::complex<typename T1::pod_type>, T1, spop_cx_scalar_times>
  228. >::result
  229. operator*
  230. (
  231. const T1& X,
  232. const std::complex<typename T1::pod_type>& k
  233. )
  234. {
  235. arma_extra_debug_sigprint();
  236. return mtSpOp<typename std::complex<typename T1::pod_type>, T1, spop_cx_scalar_times>('j', X, k);
  237. }
  238. //! complex scalar * non-complex sparse
  239. template<typename T1>
  240. arma_inline
  241. typename
  242. enable_if2
  243. <
  244. (is_arma_sparse_type<T1>::value && is_cx<typename T1::elem_type>::no),
  245. const mtSpOp<typename std::complex<typename T1::pod_type>, T1, spop_cx_scalar_times>
  246. >::result
  247. operator*
  248. (
  249. const std::complex<typename T1::pod_type>& k,
  250. const T1& X
  251. )
  252. {
  253. arma_extra_debug_sigprint();
  254. return mtSpOp<typename std::complex<typename T1::pod_type>, T1, spop_cx_scalar_times>('j', X, k);
  255. }
  256. //! multiplication of two sparse objects
  257. template<typename T1, typename T2>
  258. inline
  259. typename
  260. enable_if2
  261. <
  262. (is_arma_sparse_type<T1>::value && is_arma_sparse_type<T2>::value && is_same_type<typename T1::elem_type, typename T2::elem_type>::value),
  263. const SpGlue<T1,T2,spglue_times>
  264. >::result
  265. operator*
  266. (
  267. const T1& x,
  268. const T2& y
  269. )
  270. {
  271. arma_extra_debug_sigprint();
  272. return SpGlue<T1,T2,spglue_times>(x, y);
  273. }
  274. //! multiplication of one sparse and one dense object
  275. template<typename T1, typename T2>
  276. inline
  277. typename
  278. enable_if2
  279. <
  280. (is_arma_sparse_type<T1>::value && is_arma_type<T2>::value && is_same_type<typename T1::elem_type, typename T2::elem_type>::value),
  281. Mat<typename T1::elem_type>
  282. >::result
  283. operator*
  284. (
  285. const T1& x,
  286. const T2& y
  287. )
  288. {
  289. arma_extra_debug_sigprint();
  290. typedef typename T1::elem_type eT;
  291. Mat<eT> result;
  292. spglue_times_misc::sparse_times_dense(result, x, y);
  293. return result;
  294. }
  295. //! multiplication of one dense and one sparse object
  296. template<typename T1, typename T2>
  297. inline
  298. typename
  299. enable_if2
  300. <
  301. (is_arma_type<T1>::value && is_arma_sparse_type<T2>::value && is_same_type<typename T1::elem_type, typename T2::elem_type>::value),
  302. Mat<typename T1::elem_type>
  303. >::result
  304. operator*
  305. (
  306. const T1& x,
  307. const T2& y
  308. )
  309. {
  310. arma_extra_debug_sigprint();
  311. typedef typename T1::elem_type eT;
  312. Mat<eT> result;
  313. spglue_times_misc::dense_times_sparse(result, x, y);
  314. return result;
  315. }
  316. //! multiplication of two sparse objects with different element types
  317. template<typename T1, typename T2>
  318. inline
  319. typename
  320. enable_if2
  321. <
  322. (is_arma_sparse_type<T1>::value && is_arma_sparse_type<T2>::value && (is_same_type<typename T1::elem_type, typename T2::elem_type>::no)),
  323. const mtSpGlue< typename promote_type<typename T1::elem_type, typename T2::elem_type>::result, T1, T2, spglue_times_mixed >
  324. >::result
  325. operator*
  326. (
  327. const T1& X,
  328. const T2& Y
  329. )
  330. {
  331. arma_extra_debug_sigprint();
  332. typedef typename T1::elem_type eT1;
  333. typedef typename T2::elem_type eT2;
  334. typedef typename promote_type<eT1,eT2>::result out_eT;
  335. promote_type<eT1,eT2>::check();
  336. return mtSpGlue<out_eT, T1, T2, spglue_times_mixed>( X, Y );
  337. }
  338. //! multiplication of one sparse and one dense object with different element types
  339. template<typename T1, typename T2>
  340. inline
  341. typename
  342. enable_if2
  343. <
  344. (is_arma_sparse_type<T1>::value && is_arma_type<T2>::value && is_same_type<typename T1::elem_type, typename T2::elem_type>::no),
  345. Mat< typename promote_type<typename T1::elem_type, typename T2::elem_type>::result >
  346. >::result
  347. operator*
  348. (
  349. const T1& X,
  350. const T2& Y
  351. )
  352. {
  353. arma_extra_debug_sigprint();
  354. Mat< typename promote_type<typename T1::elem_type, typename T2::elem_type>::result > out;
  355. spglue_times_mixed::sparse_times_dense(out, X, Y);
  356. return out;
  357. }
  358. //! multiplication of one dense and one sparse object with different element types
  359. template<typename T1, typename T2>
  360. inline
  361. typename
  362. enable_if2
  363. <
  364. (is_arma_type<T1>::value && is_arma_sparse_type<T2>::value && is_same_type<typename T1::elem_type, typename T2::elem_type>::no),
  365. Mat< typename promote_type<typename T1::elem_type, typename T2::elem_type>::result >
  366. >::result
  367. operator*
  368. (
  369. const T1& X,
  370. const T2& Y
  371. )
  372. {
  373. arma_extra_debug_sigprint();
  374. Mat< typename promote_type<typename T1::elem_type, typename T2::elem_type>::result > out;
  375. spglue_times_mixed::dense_times_sparse(out, X, Y);
  376. return out;
  377. }
  378. //! @}