op_dot_meat.hpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup op_dot
  16. //! @{
  17. //! for two arrays, generic version for non-complex values
  18. template<typename eT>
  19. arma_hot
  20. arma_inline
  21. typename arma_not_cx<eT>::result
  22. op_dot::direct_dot_arma(const uword n_elem, const eT* const A, const eT* const B)
  23. {
  24. arma_extra_debug_sigprint();
  25. #if defined(__FINITE_MATH_ONLY__) && (__FINITE_MATH_ONLY__ > 0)
  26. {
  27. eT val = eT(0);
  28. for(uword i=0; i<n_elem; ++i)
  29. {
  30. val += A[i] * B[i];
  31. }
  32. return val;
  33. }
  34. #else
  35. {
  36. eT val1 = eT(0);
  37. eT val2 = eT(0);
  38. uword i, j;
  39. for(i=0, j=1; j<n_elem; i+=2, j+=2)
  40. {
  41. val1 += A[i] * B[i];
  42. val2 += A[j] * B[j];
  43. }
  44. if(i < n_elem)
  45. {
  46. val1 += A[i] * B[i];
  47. }
  48. return val1 + val2;
  49. }
  50. #endif
  51. }
  52. //! for two arrays, generic version for complex values
  53. template<typename eT>
  54. arma_hot
  55. inline
  56. typename arma_cx_only<eT>::result
  57. op_dot::direct_dot_arma(const uword n_elem, const eT* const A, const eT* const B)
  58. {
  59. arma_extra_debug_sigprint();
  60. typedef typename get_pod_type<eT>::result T;
  61. T val_real = T(0);
  62. T val_imag = T(0);
  63. for(uword i=0; i<n_elem; ++i)
  64. {
  65. const std::complex<T>& X = A[i];
  66. const std::complex<T>& Y = B[i];
  67. const T a = X.real();
  68. const T b = X.imag();
  69. const T c = Y.real();
  70. const T d = Y.imag();
  71. val_real += (a*c) - (b*d);
  72. val_imag += (a*d) + (b*c);
  73. }
  74. return std::complex<T>(val_real, val_imag);
  75. }
  76. //! for two arrays, float and double version
  77. template<typename eT>
  78. arma_hot
  79. inline
  80. typename arma_real_only<eT>::result
  81. op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B)
  82. {
  83. arma_extra_debug_sigprint();
  84. if( n_elem <= 32u )
  85. {
  86. return op_dot::direct_dot_arma(n_elem, A, B);
  87. }
  88. else
  89. {
  90. #if defined(ARMA_USE_ATLAS)
  91. {
  92. arma_extra_debug_print("atlas::cblas_dot()");
  93. return atlas::cblas_dot(n_elem, A, B);
  94. }
  95. #elif defined(ARMA_USE_BLAS)
  96. {
  97. arma_extra_debug_print("blas::dot()");
  98. return blas::dot(n_elem, A, B);
  99. }
  100. #else
  101. {
  102. return op_dot::direct_dot_arma(n_elem, A, B);
  103. }
  104. #endif
  105. }
  106. }
  107. //! for two arrays, complex version
  108. template<typename eT>
  109. inline
  110. arma_hot
  111. typename arma_cx_only<eT>::result
  112. op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B)
  113. {
  114. if( n_elem <= 16u )
  115. {
  116. return op_dot::direct_dot_arma(n_elem, A, B);
  117. }
  118. else
  119. {
  120. #if defined(ARMA_USE_ATLAS)
  121. {
  122. arma_extra_debug_print("atlas::cblas_cx_dot()");
  123. return atlas::cblas_cx_dot(n_elem, A, B);
  124. }
  125. #elif defined(ARMA_USE_BLAS)
  126. {
  127. arma_extra_debug_print("blas::dot()");
  128. return blas::dot(n_elem, A, B);
  129. }
  130. #else
  131. {
  132. return op_dot::direct_dot_arma(n_elem, A, B);
  133. }
  134. #endif
  135. }
  136. }
  137. //! for two arrays, integral version
  138. template<typename eT>
  139. arma_hot
  140. inline
  141. typename arma_integral_only<eT>::result
  142. op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B)
  143. {
  144. return op_dot::direct_dot_arma(n_elem, A, B);
  145. }
  146. //! for three arrays
  147. template<typename eT>
  148. arma_hot
  149. inline
  150. eT
  151. op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B, const eT* C)
  152. {
  153. arma_extra_debug_sigprint();
  154. eT val = eT(0);
  155. for(uword i=0; i<n_elem; ++i)
  156. {
  157. val += A[i] * B[i] * C[i];
  158. }
  159. return val;
  160. }
  161. template<typename T1, typename T2>
  162. arma_hot
  163. inline
  164. typename T1::elem_type
  165. op_dot::apply(const T1& X, const T2& Y)
  166. {
  167. arma_extra_debug_sigprint();
  168. const bool use_at = (Proxy<T1>::use_at) || (Proxy<T2>::use_at);
  169. const bool have_direct_mem = (quasi_unwrap<T1>::has_orig_mem) && (quasi_unwrap<T2>::has_orig_mem);
  170. if(use_at || have_direct_mem)
  171. {
  172. const quasi_unwrap<T1> A(X);
  173. const quasi_unwrap<T2> B(Y);
  174. arma_debug_check( (A.M.n_elem != B.M.n_elem), "dot(): objects must have the same number of elements" );
  175. return op_dot::direct_dot(A.M.n_elem, A.M.memptr(), B.M.memptr());
  176. }
  177. else
  178. {
  179. if(is_subview_row<T1>::value && is_subview_row<T2>::value)
  180. {
  181. typedef typename T1::elem_type eT;
  182. const subview_row<eT>& A = reinterpret_cast< const subview_row<eT>& >(X);
  183. const subview_row<eT>& B = reinterpret_cast< const subview_row<eT>& >(Y);
  184. if( (A.m.n_rows == 1) && (B.m.n_rows == 1) )
  185. {
  186. arma_debug_check( (A.n_elem != B.n_elem), "dot(): objects must have the same number of elements" );
  187. const eT* A_mem = A.m.memptr();
  188. const eT* B_mem = B.m.memptr();
  189. return op_dot::direct_dot(A.n_elem, &A_mem[A.aux_col1], &B_mem[B.aux_col1]);
  190. }
  191. }
  192. const Proxy<T1> PA(X);
  193. const Proxy<T2> PB(Y);
  194. arma_debug_check( (PA.get_n_elem() != PB.get_n_elem()), "dot(): objects must have the same number of elements" );
  195. if(is_Mat<typename Proxy<T1>::stored_type>::value && is_Mat<typename Proxy<T2>::stored_type>::value)
  196. {
  197. const quasi_unwrap<typename Proxy<T1>::stored_type> A(PA.Q);
  198. const quasi_unwrap<typename Proxy<T2>::stored_type> B(PB.Q);
  199. return op_dot::direct_dot(A.M.n_elem, A.M.memptr(), B.M.memptr());
  200. }
  201. return op_dot::apply_proxy(PA,PB);
  202. }
  203. }
  204. template<typename T1, typename T2>
  205. arma_hot
  206. inline
  207. typename arma_not_cx<typename T1::elem_type>::result
  208. op_dot::apply_proxy(const Proxy<T1>& PA, const Proxy<T2>& PB)
  209. {
  210. arma_extra_debug_sigprint();
  211. typedef typename T1::elem_type eT;
  212. typedef typename Proxy<T1>::ea_type ea_type1;
  213. typedef typename Proxy<T2>::ea_type ea_type2;
  214. const uword N = PA.get_n_elem();
  215. ea_type1 A = PA.get_ea();
  216. ea_type2 B = PB.get_ea();
  217. eT val1 = eT(0);
  218. eT val2 = eT(0);
  219. uword i,j;
  220. for(i=0, j=1; j<N; i+=2, j+=2)
  221. {
  222. val1 += A[i] * B[i];
  223. val2 += A[j] * B[j];
  224. }
  225. if(i < N)
  226. {
  227. val1 += A[i] * B[i];
  228. }
  229. return val1 + val2;
  230. }
  231. template<typename T1, typename T2>
  232. arma_hot
  233. inline
  234. typename arma_cx_only<typename T1::elem_type>::result
  235. op_dot::apply_proxy(const Proxy<T1>& PA, const Proxy<T2>& PB)
  236. {
  237. arma_extra_debug_sigprint();
  238. typedef typename T1::elem_type eT;
  239. typedef typename get_pod_type<eT>::result T;
  240. typedef typename Proxy<T1>::ea_type ea_type1;
  241. typedef typename Proxy<T2>::ea_type ea_type2;
  242. const uword N = PA.get_n_elem();
  243. ea_type1 A = PA.get_ea();
  244. ea_type2 B = PB.get_ea();
  245. T val_real = T(0);
  246. T val_imag = T(0);
  247. for(uword i=0; i<N; ++i)
  248. {
  249. const std::complex<T> xx = A[i];
  250. const std::complex<T> yy = B[i];
  251. const T a = xx.real();
  252. const T b = xx.imag();
  253. const T c = yy.real();
  254. const T d = yy.imag();
  255. val_real += (a*c) - (b*d);
  256. val_imag += (a*d) + (b*c);
  257. }
  258. return std::complex<T>(val_real, val_imag);
  259. }
  260. //
  261. // op_norm_dot
  262. template<typename T1, typename T2>
  263. arma_hot
  264. inline
  265. typename T1::elem_type
  266. op_norm_dot::apply(const T1& X, const T2& Y)
  267. {
  268. arma_extra_debug_sigprint();
  269. typedef typename T1::elem_type eT;
  270. typedef typename T1::pod_type T;
  271. const quasi_unwrap<T1> tmp1(X);
  272. const quasi_unwrap<T2> tmp2(Y);
  273. const Col<eT> A( const_cast<eT*>(tmp1.M.memptr()), tmp1.M.n_elem, false );
  274. const Col<eT> B( const_cast<eT*>(tmp2.M.memptr()), tmp2.M.n_elem, false );
  275. arma_debug_check( (A.n_elem != B.n_elem), "norm_dot(): objects must have the same number of elements" );
  276. const T denom = norm(A,2) * norm(B,2);
  277. return (denom != T(0)) ? ( op_dot::apply(A,B) / denom ) : eT(0);
  278. }
  279. //
  280. // op_cdot
  281. template<typename eT>
  282. arma_hot
  283. inline
  284. eT
  285. op_cdot::direct_cdot_arma(const uword n_elem, const eT* const A, const eT* const B)
  286. {
  287. arma_extra_debug_sigprint();
  288. typedef typename get_pod_type<eT>::result T;
  289. T val_real = T(0);
  290. T val_imag = T(0);
  291. for(uword i=0; i<n_elem; ++i)
  292. {
  293. const std::complex<T>& X = A[i];
  294. const std::complex<T>& Y = B[i];
  295. const T a = X.real();
  296. const T b = X.imag();
  297. const T c = Y.real();
  298. const T d = Y.imag();
  299. val_real += (a*c) + (b*d);
  300. val_imag += (a*d) - (b*c);
  301. }
  302. return std::complex<T>(val_real, val_imag);
  303. }
  304. template<typename eT>
  305. arma_hot
  306. inline
  307. eT
  308. op_cdot::direct_cdot(const uword n_elem, const eT* const A, const eT* const B)
  309. {
  310. arma_extra_debug_sigprint();
  311. if( n_elem <= 32u )
  312. {
  313. return op_cdot::direct_cdot_arma(n_elem, A, B);
  314. }
  315. else
  316. {
  317. #if defined(ARMA_USE_BLAS)
  318. {
  319. arma_extra_debug_print("blas::gemv()");
  320. // using gemv() workaround due to compatibility issues with cdotc() and zdotc()
  321. const char trans = 'C';
  322. const blas_int m = blas_int(n_elem);
  323. const blas_int n = 1;
  324. //const blas_int lda = (n_elem > 0) ? blas_int(n_elem) : blas_int(1);
  325. const blas_int inc = 1;
  326. const eT alpha = eT(1);
  327. const eT beta = eT(0);
  328. eT result[2]; // paranoia: using two elements instead of one
  329. //blas::gemv(&trans, &m, &n, &alpha, A, &lda, B, &inc, &beta, &result[0], &inc);
  330. blas::gemv(&trans, &m, &n, &alpha, A, &m, B, &inc, &beta, &result[0], &inc);
  331. return result[0];
  332. }
  333. #elif defined(ARMA_USE_ATLAS)
  334. {
  335. // TODO: use dedicated atlas functions cblas_cdotc_sub() and cblas_zdotc_sub() and retune threshold
  336. return op_cdot::direct_cdot_arma(n_elem, A, B);
  337. }
  338. #else
  339. {
  340. return op_cdot::direct_cdot_arma(n_elem, A, B);
  341. }
  342. #endif
  343. }
  344. }
  345. template<typename T1, typename T2>
  346. arma_hot
  347. inline
  348. typename T1::elem_type
  349. op_cdot::apply(const T1& X, const T2& Y)
  350. {
  351. arma_extra_debug_sigprint();
  352. if( (is_Mat<T1>::value == true) && (is_Mat<T2>::value == true) )
  353. {
  354. return op_cdot::apply_unwrap(X,Y);
  355. }
  356. else
  357. {
  358. return op_cdot::apply_proxy(X,Y);
  359. }
  360. }
  361. template<typename T1, typename T2>
  362. arma_hot
  363. inline
  364. typename T1::elem_type
  365. op_cdot::apply_unwrap(const T1& X, const T2& Y)
  366. {
  367. arma_extra_debug_sigprint();
  368. typedef typename T1::elem_type eT;
  369. const unwrap<T1> tmp1(X);
  370. const unwrap<T2> tmp2(Y);
  371. const Mat<eT>& A = tmp1.M;
  372. const Mat<eT>& B = tmp2.M;
  373. arma_debug_check( (A.n_elem != B.n_elem), "cdot(): objects must have the same number of elements" );
  374. return op_cdot::direct_cdot( A.n_elem, A.mem, B.mem );
  375. }
  376. template<typename T1, typename T2>
  377. arma_hot
  378. inline
  379. typename T1::elem_type
  380. op_cdot::apply_proxy(const T1& X, const T2& Y)
  381. {
  382. arma_extra_debug_sigprint();
  383. typedef typename T1::elem_type eT;
  384. typedef typename get_pod_type<eT>::result T;
  385. typedef typename Proxy<T1>::ea_type ea_type1;
  386. typedef typename Proxy<T2>::ea_type ea_type2;
  387. const bool use_at = (Proxy<T1>::use_at) || (Proxy<T2>::use_at);
  388. if(use_at == false)
  389. {
  390. const Proxy<T1> PA(X);
  391. const Proxy<T2> PB(Y);
  392. const uword N = PA.get_n_elem();
  393. arma_debug_check( (N != PB.get_n_elem()), "cdot(): objects must have the same number of elements" );
  394. ea_type1 A = PA.get_ea();
  395. ea_type2 B = PB.get_ea();
  396. T val_real = T(0);
  397. T val_imag = T(0);
  398. for(uword i=0; i<N; ++i)
  399. {
  400. const std::complex<T> AA = A[i];
  401. const std::complex<T> BB = B[i];
  402. const T a = AA.real();
  403. const T b = AA.imag();
  404. const T c = BB.real();
  405. const T d = BB.imag();
  406. val_real += (a*c) + (b*d);
  407. val_imag += (a*d) - (b*c);
  408. }
  409. return std::complex<T>(val_real, val_imag);
  410. }
  411. else
  412. {
  413. return op_cdot::apply_unwrap( X, Y );
  414. }
  415. }
  416. template<typename T1, typename T2>
  417. arma_hot
  418. inline
  419. typename promote_type<typename T1::elem_type, typename T2::elem_type>::result
  420. op_dot_mixed::apply(const T1& A, const T2& B)
  421. {
  422. arma_extra_debug_sigprint();
  423. typedef typename T1::elem_type in_eT1;
  424. typedef typename T2::elem_type in_eT2;
  425. typedef typename promote_type<in_eT1, in_eT2>::result out_eT;
  426. const Proxy<T1> PA(A);
  427. const Proxy<T2> PB(B);
  428. const uword N = PA.get_n_elem();
  429. arma_debug_check( (N != PB.get_n_elem()), "dot(): objects must have the same number of elements" );
  430. out_eT acc = out_eT(0);
  431. for(uword i=0; i < N; ++i)
  432. {
  433. acc += upgrade_val<in_eT1,in_eT2>::apply(PA[i]) * upgrade_val<in_eT1,in_eT2>::apply(PB[i]);
  434. }
  435. return acc;
  436. }
  437. //! @}