spglue_schur_meat.hpp 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup spglue_schur
  16. //! @{
  17. template<typename T1, typename T2>
  18. arma_hot
  19. inline
  20. void
  21. spglue_schur::apply(SpMat<typename T1::elem_type>& out, const SpGlue<T1,T2,spglue_schur>& X)
  22. {
  23. arma_extra_debug_sigprint();
  24. typedef typename T1::elem_type eT;
  25. const SpProxy<T1> pa(X.A);
  26. const SpProxy<T2> pb(X.B);
  27. const bool is_alias = pa.is_alias(out) || pb.is_alias(out);
  28. if(is_alias == false)
  29. {
  30. spglue_schur::apply_noalias(out, pa, pb);
  31. }
  32. else
  33. {
  34. SpMat<eT> tmp;
  35. spglue_schur::apply_noalias(tmp, pa, pb);
  36. out.steal_mem(tmp);
  37. }
  38. }
  39. template<typename eT, typename T1, typename T2>
  40. arma_hot
  41. inline
  42. void
  43. spglue_schur::apply_noalias(SpMat<eT>& out, const SpProxy<T1>& pa, const SpProxy<T2>& pb)
  44. {
  45. arma_extra_debug_sigprint();
  46. arma_debug_assert_same_size(pa.get_n_rows(), pa.get_n_cols(), pb.get_n_rows(), pb.get_n_cols(), "element-wise multiplication");
  47. if( (pa.get_n_nonzero() == 0) || (pb.get_n_nonzero() == 0) )
  48. {
  49. out.zeros(pa.get_n_rows(), pa.get_n_cols());
  50. return;
  51. }
  52. const uword max_n_nonzero = spglue_elem_helper::max_n_nonzero_schur(pa, pb);
  53. // Resize memory to upper bound
  54. out.reserve(pa.get_n_rows(), pa.get_n_cols(), max_n_nonzero);
  55. // Now iterate across both matrices.
  56. typename SpProxy<T1>::const_iterator_type x_it = pa.begin();
  57. typename SpProxy<T1>::const_iterator_type x_end = pa.end();
  58. typename SpProxy<T2>::const_iterator_type y_it = pb.begin();
  59. typename SpProxy<T2>::const_iterator_type y_end = pb.end();
  60. uword count = 0;
  61. while( (x_it != x_end) || (y_it != y_end) )
  62. {
  63. const uword x_it_row = x_it.row();
  64. const uword x_it_col = x_it.col();
  65. const uword y_it_row = y_it.row();
  66. const uword y_it_col = y_it.col();
  67. if(x_it == y_it)
  68. {
  69. const eT out_val = (*x_it) * (*y_it);
  70. if(out_val != eT(0))
  71. {
  72. access::rw(out.values[count]) = out_val;
  73. access::rw(out.row_indices[count]) = x_it_row;
  74. access::rw(out.col_ptrs[x_it_col + 1])++;
  75. ++count;
  76. }
  77. ++x_it;
  78. ++y_it;
  79. }
  80. else
  81. {
  82. if((x_it_col < y_it_col) || ((x_it_col == y_it_col) && (x_it_row < y_it_row))) // if y is closer to the end
  83. {
  84. ++x_it;
  85. }
  86. else
  87. {
  88. ++y_it;
  89. }
  90. }
  91. }
  92. const uword out_n_cols = out.n_cols;
  93. uword* col_ptrs = access::rwp(out.col_ptrs);
  94. // Fix column pointers to be cumulative.
  95. for(uword c = 1; c <= out_n_cols; ++c)
  96. {
  97. col_ptrs[c] += col_ptrs[c - 1];
  98. }
  99. if(count < max_n_nonzero)
  100. {
  101. if(count <= (max_n_nonzero/2))
  102. {
  103. out.mem_resize(count);
  104. }
  105. else
  106. {
  107. // quick resize without reallocating memory and copying data
  108. access::rw( out.n_nonzero) = count;
  109. access::rw( out.values[count]) = eT(0);
  110. access::rw(out.row_indices[count]) = uword(0);
  111. }
  112. }
  113. }
  114. template<typename eT>
  115. arma_hot
  116. inline
  117. void
  118. spglue_schur::apply_noalias(SpMat<eT>& out, const SpMat<eT>& A, const SpMat<eT>& B)
  119. {
  120. arma_extra_debug_sigprint();
  121. const SpProxy< SpMat<eT> > pa(A);
  122. const SpProxy< SpMat<eT> > pb(B);
  123. spglue_schur::apply_noalias(out, pa, pb);
  124. }
  125. //
  126. //
  127. //
  128. template<typename T1, typename T2>
  129. inline
  130. void
  131. spglue_schur_misc::dense_schur_sparse(SpMat<typename T1::elem_type>& out, const T1& x, const T2& y)
  132. {
  133. arma_extra_debug_sigprint();
  134. typedef typename T1::elem_type eT;
  135. const Proxy<T1> pa(x);
  136. const SpProxy<T2> pb(y);
  137. arma_debug_assert_same_size(pa.get_n_rows(), pa.get_n_cols(), pb.get_n_rows(), pb.get_n_cols(), "element-wise multiplication");
  138. // count new size
  139. uword new_n_nonzero = 0;
  140. typename SpProxy<T2>::const_iterator_type it = pb.begin();
  141. typename SpProxy<T2>::const_iterator_type it_end = pb.end();
  142. while(it != it_end)
  143. {
  144. if( ((*it) * pa.at(it.row(), it.col())) != eT(0) ) { ++new_n_nonzero; }
  145. ++it;
  146. }
  147. // Resize memory accordingly.
  148. out.reserve(pa.get_n_rows(), pa.get_n_cols(), new_n_nonzero);
  149. uword count = 0;
  150. typename SpProxy<T2>::const_iterator_type it2 = pb.begin();
  151. while(it2 != it_end)
  152. {
  153. const uword it2_row = it2.row();
  154. const uword it2_col = it2.col();
  155. const eT val = (*it2) * pa.at(it2_row, it2_col);
  156. if(val != eT(0))
  157. {
  158. access::rw( out.values[count]) = val;
  159. access::rw( out.row_indices[count]) = it2_row;
  160. access::rw(out.col_ptrs[it2_col + 1])++;
  161. ++count;
  162. }
  163. ++it2;
  164. }
  165. // Fix column pointers.
  166. for(uword c = 1; c <= out.n_cols; ++c)
  167. {
  168. access::rw(out.col_ptrs[c]) += out.col_ptrs[c - 1];
  169. }
  170. }
  171. //
  172. template<typename T1, typename T2>
  173. inline
  174. void
  175. spglue_schur_mixed::apply(SpMat<typename eT_promoter<T1,T2>::eT>& out, const mtSpGlue<typename eT_promoter<T1,T2>::eT, T1, T2, spglue_schur_mixed>& expr)
  176. {
  177. arma_extra_debug_sigprint();
  178. typedef typename T1::elem_type eT1;
  179. typedef typename T2::elem_type eT2;
  180. typedef typename promote_type<eT1,eT2>::result out_eT;
  181. promote_type<eT1,eT2>::check();
  182. if( (is_same_type<eT1,out_eT>::no) && (is_same_type<eT2,out_eT>::yes) )
  183. {
  184. // upgrade T1
  185. const unwrap_spmat<T1> UA(expr.A);
  186. const unwrap_spmat<T2> UB(expr.B);
  187. const SpMat<eT1>& A = UA.M;
  188. const SpMat<eT2>& B = UB.M;
  189. SpMat<out_eT> AA(arma_layout_indicator(), A);
  190. for(uword i=0; i < A.n_nonzero; ++i) { access::rw(AA.values[i]) = out_eT(A.values[i]); }
  191. const SpMat<out_eT>& BB = reinterpret_cast< const SpMat<out_eT>& >(B);
  192. out = AA % BB;
  193. }
  194. else
  195. if( (is_same_type<eT1,out_eT>::yes) && (is_same_type<eT2,out_eT>::no) )
  196. {
  197. // upgrade T2
  198. const unwrap_spmat<T1> UA(expr.A);
  199. const unwrap_spmat<T2> UB(expr.B);
  200. const SpMat<eT1>& A = UA.M;
  201. const SpMat<eT2>& B = UB.M;
  202. const SpMat<out_eT>& AA = reinterpret_cast< const SpMat<out_eT>& >(A);
  203. SpMat<out_eT> BB(arma_layout_indicator(), B);
  204. for(uword i=0; i < B.n_nonzero; ++i) { access::rw(BB.values[i]) = out_eT(B.values[i]); }
  205. out = AA % BB;
  206. }
  207. else
  208. {
  209. // upgrade T1 and T2
  210. const unwrap_spmat<T1> UA(expr.A);
  211. const unwrap_spmat<T2> UB(expr.B);
  212. const SpMat<eT1>& A = UA.M;
  213. const SpMat<eT2>& B = UB.M;
  214. SpMat<out_eT> AA(arma_layout_indicator(), A);
  215. SpMat<out_eT> BB(arma_layout_indicator(), B);
  216. for(uword i=0; i < A.n_nonzero; ++i) { access::rw(AA.values[i]) = out_eT(A.values[i]); }
  217. for(uword i=0; i < B.n_nonzero; ++i) { access::rw(BB.values[i]) = out_eT(B.values[i]); }
  218. out = AA % BB;
  219. }
  220. }
  221. template<typename T1, typename T2>
  222. inline
  223. void
  224. spglue_schur_mixed::dense_schur_sparse(SpMat< typename promote_type<typename T1::elem_type, typename T2::elem_type >::result>& out, const T1& X, const T2& Y)
  225. {
  226. arma_extra_debug_sigprint();
  227. typedef typename T1::elem_type eT1;
  228. typedef typename T2::elem_type eT2;
  229. typedef typename promote_type<eT1,eT2>::result out_eT;
  230. promote_type<eT1,eT2>::check();
  231. const Proxy<T1> pa(X);
  232. const SpProxy<T2> pb(Y);
  233. arma_debug_assert_same_size(pa.get_n_rows(), pa.get_n_cols(), pb.get_n_rows(), pb.get_n_cols(), "element-wise multiplication");
  234. // count new size
  235. uword new_n_nonzero = 0;
  236. typename SpProxy<T2>::const_iterator_type it = pb.begin();
  237. typename SpProxy<T2>::const_iterator_type it_end = pb.end();
  238. while(it != it_end)
  239. {
  240. if( (out_eT(*it) * out_eT(pa.at(it.row(), it.col()))) != out_eT(0) ) { ++new_n_nonzero; }
  241. ++it;
  242. }
  243. // Resize memory accordingly.
  244. out.reserve(pa.get_n_rows(), pa.get_n_cols(), new_n_nonzero);
  245. uword count = 0;
  246. typename SpProxy<T2>::const_iterator_type it2 = pb.begin();
  247. while(it2 != it_end)
  248. {
  249. const uword it2_row = it2.row();
  250. const uword it2_col = it2.col();
  251. const out_eT val = out_eT(*it2) * out_eT(pa.at(it2_row, it2_col));
  252. if(val != out_eT(0))
  253. {
  254. access::rw( out.values[count]) = val;
  255. access::rw( out.row_indices[count]) = it2_row;
  256. access::rw(out.col_ptrs[it2_col + 1])++;
  257. ++count;
  258. }
  259. ++it2;
  260. }
  261. // Fix column pointers.
  262. for(uword c = 1; c <= out.n_cols; ++c)
  263. {
  264. access::rw(out.col_ptrs[c]) += out.col_ptrs[c - 1];
  265. }
  266. }
  267. //! @}