fn_dot.hpp 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup fn_dot
  16. //! @{
  17. template<typename T1, typename T2>
  18. arma_warn_unused
  19. arma_inline
  20. typename
  21. enable_if2
  22. <
  23. is_arma_type<T1>::value && is_arma_type<T2>::value && is_same_type<typename T1::elem_type, typename T2::elem_type>::yes,
  24. typename T1::elem_type
  25. >::result
  26. dot
  27. (
  28. const T1& A,
  29. const T2& B
  30. )
  31. {
  32. arma_extra_debug_sigprint();
  33. return op_dot::apply(A,B);
  34. }
  35. template<typename T1, typename T2>
  36. arma_warn_unused
  37. inline
  38. typename
  39. enable_if2
  40. <
  41. is_arma_type<T1>::value && is_arma_type<T2>::value && is_same_type<typename T1::elem_type, typename T2::elem_type>::no,
  42. typename promote_type<typename T1::elem_type, typename T2::elem_type>::result
  43. >::result
  44. dot
  45. (
  46. const T1& A,
  47. const T2& B
  48. )
  49. {
  50. arma_extra_debug_sigprint();
  51. return op_dot_mixed::apply(A,B);
  52. }
  53. template<typename T1, typename T2>
  54. arma_warn_unused
  55. inline
  56. typename
  57. enable_if2
  58. <
  59. is_arma_type<T1>::value && is_arma_type<T2>::value && is_same_type<typename T1::elem_type, typename T2::elem_type>::value,
  60. typename T1::elem_type
  61. >::result
  62. norm_dot
  63. (
  64. const T1& A,
  65. const T2& B
  66. )
  67. {
  68. arma_extra_debug_sigprint();
  69. return op_norm_dot::apply(A,B);
  70. }
  71. //
  72. // cdot
  73. template<typename T1, typename T2>
  74. arma_warn_unused
  75. arma_inline
  76. typename
  77. enable_if2
  78. <
  79. is_arma_type<T1>::value && is_arma_type<T2>::value && is_same_type<typename T1::elem_type, typename T2::elem_type>::value && is_cx<typename T1::elem_type>::no,
  80. typename T1::elem_type
  81. >::result
  82. cdot
  83. (
  84. const T1& A,
  85. const T2& B
  86. )
  87. {
  88. arma_extra_debug_sigprint();
  89. return op_dot::apply(A,B);
  90. }
  91. template<typename T1, typename T2>
  92. arma_warn_unused
  93. arma_inline
  94. typename
  95. enable_if2
  96. <
  97. is_arma_type<T1>::value && is_arma_type<T2>::value && is_same_type<typename T1::elem_type, typename T2::elem_type>::value && is_cx<typename T1::elem_type>::yes,
  98. typename T1::elem_type
  99. >::result
  100. cdot
  101. (
  102. const T1& A,
  103. const T2& B
  104. )
  105. {
  106. arma_extra_debug_sigprint();
  107. return op_cdot::apply(A,B);
  108. }
  109. // convert dot(htrans(x), y) to cdot(x,y)
  110. template<typename T1, typename T2>
  111. arma_warn_unused
  112. arma_inline
  113. typename
  114. enable_if2
  115. <
  116. is_arma_type<T2>::value && is_same_type<typename T1::elem_type, typename T2::elem_type>::value && is_cx<typename T1::elem_type>::yes,
  117. typename T1::elem_type
  118. >::result
  119. dot
  120. (
  121. const Op<T1, op_htrans>& A,
  122. const T2& B
  123. )
  124. {
  125. arma_extra_debug_sigprint();
  126. return cdot(A.m, B);
  127. }
  128. //
  129. // for sparse matrices
  130. //
  131. namespace priv
  132. {
  133. template<typename T1, typename T2>
  134. arma_hot
  135. inline
  136. typename T1::elem_type
  137. dot_helper(const SpProxy<T1>& pa, const SpProxy<T2>& pb)
  138. {
  139. typedef typename T1::elem_type eT;
  140. // Iterate over both objects and see when they are the same
  141. eT result = eT(0);
  142. typename SpProxy<T1>::const_iterator_type a_it = pa.begin();
  143. typename SpProxy<T1>::const_iterator_type a_end = pa.end();
  144. typename SpProxy<T2>::const_iterator_type b_it = pb.begin();
  145. typename SpProxy<T2>::const_iterator_type b_end = pb.end();
  146. while((a_it != a_end) && (b_it != b_end))
  147. {
  148. if(a_it == b_it)
  149. {
  150. result += (*a_it) * (*b_it);
  151. ++a_it;
  152. ++b_it;
  153. }
  154. else if((a_it.col() < b_it.col()) || ((a_it.col() == b_it.col()) && (a_it.row() < b_it.row())))
  155. {
  156. // a_it is "behind"
  157. ++a_it;
  158. }
  159. else
  160. {
  161. // b_it is "behind"
  162. ++b_it;
  163. }
  164. }
  165. return result;
  166. }
  167. }
  168. //! dot product of two sparse objects
  169. template<typename T1, typename T2>
  170. arma_warn_unused
  171. arma_hot
  172. inline
  173. typename
  174. enable_if2
  175. <(is_arma_sparse_type<T1>::value) && (is_arma_sparse_type<T2>::value) && (is_same_type<typename T1::elem_type, typename T2::elem_type>::value),
  176. typename T1::elem_type
  177. >::result
  178. dot
  179. (
  180. const T1& x,
  181. const T2& y
  182. )
  183. {
  184. arma_extra_debug_sigprint();
  185. const SpProxy<T1> pa(x);
  186. const SpProxy<T2> pb(y);
  187. arma_debug_assert_same_size(pa.get_n_rows(), pa.get_n_cols(), pb.get_n_rows(), pb.get_n_cols(), "dot()");
  188. typedef typename T1::elem_type eT;
  189. typedef typename SpProxy<T1>::stored_type pa_Q_type;
  190. typedef typename SpProxy<T2>::stored_type pb_Q_type;
  191. if(
  192. ( (SpProxy<T1>::use_iterator == false) && (SpProxy<T2>::use_iterator == false) )
  193. && ( (is_SpMat<pa_Q_type>::value == true ) && (is_SpMat<pb_Q_type>::value == true ) )
  194. )
  195. {
  196. const unwrap_spmat<pa_Q_type> tmp_a(pa.Q);
  197. const unwrap_spmat<pb_Q_type> tmp_b(pb.Q);
  198. const SpMat<eT>& A = tmp_a.M;
  199. const SpMat<eT>& B = tmp_b.M;
  200. if( &A == &B )
  201. {
  202. // We can do it directly!
  203. return op_dot::direct_dot_arma(A.n_nonzero, A.values, A.values);
  204. }
  205. else
  206. {
  207. return priv::dot_helper(pa,pb);
  208. }
  209. }
  210. else
  211. {
  212. return priv::dot_helper(pa,pb);
  213. }
  214. }
  215. //! dot product of one dense and one sparse object
  216. template<typename T1, typename T2>
  217. arma_warn_unused
  218. arma_hot
  219. inline
  220. typename
  221. enable_if2
  222. <(is_arma_type<T1>::value) && (is_arma_sparse_type<T2>::value) && (is_same_type<typename T1::elem_type, typename T2::elem_type>::value),
  223. typename T1::elem_type
  224. >::result
  225. dot
  226. (
  227. const T1& x,
  228. const T2& y
  229. )
  230. {
  231. arma_extra_debug_sigprint();
  232. const Proxy<T1> pa(x);
  233. const SpProxy<T2> pb(y);
  234. arma_debug_assert_same_size(pa.get_n_rows(), pa.get_n_cols(), pb.get_n_rows(), pb.get_n_cols(), "dot()");
  235. typedef typename T1::elem_type eT;
  236. eT result = eT(0);
  237. typename SpProxy<T2>::const_iterator_type it = pb.begin();
  238. typename SpProxy<T2>::const_iterator_type it_end = pb.end();
  239. // use_at == false won't save us operations
  240. while(it != it_end)
  241. {
  242. result += (*it) * pa.at(it.row(), it.col());
  243. ++it;
  244. }
  245. return result;
  246. }
  247. //! dot product of one sparse and one dense object
  248. template<typename T1, typename T2>
  249. arma_warn_unused
  250. arma_hot
  251. inline
  252. typename
  253. enable_if2
  254. <(is_arma_sparse_type<T1>::value) && (is_arma_type<T2>::value) && (is_same_type<typename T1::elem_type, typename T2::elem_type>::value),
  255. typename T1::elem_type
  256. >::result
  257. dot
  258. (
  259. const T1& x,
  260. const T2& y
  261. )
  262. {
  263. arma_extra_debug_sigprint();
  264. // this is commutative
  265. return dot(y, x);
  266. }
  267. //! @}