spglue_minus_meat.hpp 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup spglue_minus
  16. //! @{
  17. template<typename T1, typename T2>
  18. arma_hot
  19. inline
  20. void
  21. spglue_minus::apply(SpMat<typename T1::elem_type>& out, const SpGlue<T1,T2,spglue_minus>& X)
  22. {
  23. arma_extra_debug_sigprint();
  24. typedef typename T1::elem_type eT;
  25. const SpProxy<T1> pa(X.A);
  26. const SpProxy<T2> pb(X.B);
  27. const bool is_alias = pa.is_alias(out) || pb.is_alias(out);
  28. if(is_alias == false)
  29. {
  30. spglue_minus::apply_noalias(out, pa, pb);
  31. }
  32. else
  33. {
  34. SpMat<eT> tmp;
  35. spglue_minus::apply_noalias(tmp, pa, pb);
  36. out.steal_mem(tmp);
  37. }
  38. }
  39. template<typename eT, typename T1, typename T2>
  40. arma_hot
  41. inline
  42. void
  43. spglue_minus::apply_noalias(SpMat<eT>& out, const SpProxy<T1>& pa, const SpProxy<T2>& pb)
  44. {
  45. arma_extra_debug_sigprint();
  46. arma_debug_assert_same_size(pa.get_n_rows(), pa.get_n_cols(), pb.get_n_rows(), pb.get_n_cols(), "subtraction");
  47. if(pa.get_n_nonzero() == 0) { out = pb.Q; out *= eT(-1); return; }
  48. if(pb.get_n_nonzero() == 0) { out = pa.Q; return; }
  49. const uword max_n_nonzero = spglue_elem_helper::max_n_nonzero_plus(pa, pb);
  50. // Resize memory to upper bound
  51. out.reserve(pa.get_n_rows(), pa.get_n_cols(), max_n_nonzero);
  52. // Now iterate across both matrices.
  53. typename SpProxy<T1>::const_iterator_type x_it = pa.begin();
  54. typename SpProxy<T1>::const_iterator_type x_end = pa.end();
  55. typename SpProxy<T2>::const_iterator_type y_it = pb.begin();
  56. typename SpProxy<T2>::const_iterator_type y_end = pb.end();
  57. uword count = 0;
  58. while( (x_it != x_end) || (y_it != y_end) )
  59. {
  60. eT out_val;
  61. const uword x_it_row = x_it.row();
  62. const uword x_it_col = x_it.col();
  63. const uword y_it_row = y_it.row();
  64. const uword y_it_col = y_it.col();
  65. bool use_y_loc = false;
  66. if(x_it == y_it)
  67. {
  68. out_val = (*x_it) - (*y_it);
  69. ++x_it;
  70. ++y_it;
  71. }
  72. else
  73. {
  74. if((x_it_col < y_it_col) || ((x_it_col == y_it_col) && (x_it_row < y_it_row))) // if y is closer to the end
  75. {
  76. out_val = (*x_it);
  77. ++x_it;
  78. }
  79. else
  80. {
  81. out_val = -(*y_it); // take the negative
  82. ++y_it;
  83. use_y_loc = true;
  84. }
  85. }
  86. if(out_val != eT(0))
  87. {
  88. access::rw(out.values[count]) = out_val;
  89. const uword out_row = (use_y_loc == false) ? x_it_row : y_it_row;
  90. const uword out_col = (use_y_loc == false) ? x_it_col : y_it_col;
  91. access::rw(out.row_indices[count]) = out_row;
  92. access::rw(out.col_ptrs[out_col + 1])++;
  93. ++count;
  94. }
  95. }
  96. const uword out_n_cols = out.n_cols;
  97. uword* col_ptrs = access::rwp(out.col_ptrs);
  98. // Fix column pointers to be cumulative.
  99. for(uword c = 1; c <= out_n_cols; ++c)
  100. {
  101. col_ptrs[c] += col_ptrs[c - 1];
  102. }
  103. if(count < max_n_nonzero)
  104. {
  105. if(count <= (max_n_nonzero/2))
  106. {
  107. out.mem_resize(count);
  108. }
  109. else
  110. {
  111. // quick resize without reallocating memory and copying data
  112. access::rw( out.n_nonzero) = count;
  113. access::rw( out.values[count]) = eT(0);
  114. access::rw(out.row_indices[count]) = uword(0);
  115. }
  116. }
  117. }
  118. template<typename eT>
  119. arma_hot
  120. inline
  121. void
  122. spglue_minus::apply_noalias(SpMat<eT>& out, const SpMat<eT>& A, const SpMat<eT>& B)
  123. {
  124. arma_extra_debug_sigprint();
  125. const SpProxy< SpMat<eT> > pa(A);
  126. const SpProxy< SpMat<eT> > pb(B);
  127. spglue_minus::apply_noalias(out, pa, pb);
  128. }
  129. //
  130. template<typename T1, typename T2>
  131. inline
  132. void
  133. spglue_minus_mixed::apply(SpMat<typename eT_promoter<T1,T2>::eT>& out, const mtSpGlue<typename eT_promoter<T1,T2>::eT, T1, T2, spglue_minus_mixed>& expr)
  134. {
  135. arma_extra_debug_sigprint();
  136. typedef typename T1::elem_type eT1;
  137. typedef typename T2::elem_type eT2;
  138. typedef typename promote_type<eT1,eT2>::result out_eT;
  139. promote_type<eT1,eT2>::check();
  140. if( (is_same_type<eT1,out_eT>::no) && (is_same_type<eT2,out_eT>::yes) )
  141. {
  142. // upgrade T1
  143. const unwrap_spmat<T1> UA(expr.A);
  144. const unwrap_spmat<T2> UB(expr.B);
  145. const SpMat<eT1>& A = UA.M;
  146. const SpMat<eT2>& B = UB.M;
  147. SpMat<out_eT> AA(arma_layout_indicator(), A);
  148. for(uword i=0; i < A.n_nonzero; ++i) { access::rw(AA.values[i]) = out_eT(A.values[i]); }
  149. const SpMat<out_eT>& BB = reinterpret_cast< const SpMat<out_eT>& >(B);
  150. out = AA - BB;
  151. }
  152. else
  153. if( (is_same_type<eT1,out_eT>::yes) && (is_same_type<eT2,out_eT>::no) )
  154. {
  155. // upgrade T2
  156. const unwrap_spmat<T1> UA(expr.A);
  157. const unwrap_spmat<T2> UB(expr.B);
  158. const SpMat<eT1>& A = UA.M;
  159. const SpMat<eT2>& B = UB.M;
  160. const SpMat<out_eT>& AA = reinterpret_cast< const SpMat<out_eT>& >(A);
  161. SpMat<out_eT> BB(arma_layout_indicator(), B);
  162. for(uword i=0; i < B.n_nonzero; ++i) { access::rw(BB.values[i]) = out_eT(B.values[i]); }
  163. out = AA - BB;
  164. }
  165. else
  166. {
  167. // upgrade T1 and T2
  168. const unwrap_spmat<T1> UA(expr.A);
  169. const unwrap_spmat<T2> UB(expr.B);
  170. const SpMat<eT1>& A = UA.M;
  171. const SpMat<eT2>& B = UB.M;
  172. SpMat<out_eT> AA(arma_layout_indicator(), A);
  173. SpMat<out_eT> BB(arma_layout_indicator(), B);
  174. for(uword i=0; i < A.n_nonzero; ++i) { access::rw(AA.values[i]) = out_eT(A.values[i]); }
  175. for(uword i=0; i < B.n_nonzero; ++i) { access::rw(BB.values[i]) = out_eT(B.values[i]); }
  176. out = AA - BB;
  177. }
  178. }
  179. template<typename T1, typename T2>
  180. inline
  181. void
  182. spglue_minus_mixed::sparse_minus_dense(Mat< typename promote_type<typename T1::elem_type, typename T2::elem_type >::result>& out, const T1& X, const T2& Y)
  183. {
  184. arma_extra_debug_sigprint();
  185. typedef typename T1::elem_type eT1;
  186. typedef typename T2::elem_type eT2;
  187. typedef typename promote_type<eT1,eT2>::result out_eT;
  188. promote_type<eT1,eT2>::check();
  189. const quasi_unwrap<T2> UB(Y);
  190. const Mat<eT2>& B = UB.M;
  191. const uword B_n_elem = B.n_elem;
  192. const eT2* B_mem = B.memptr();
  193. out.set_size(B.n_rows, B.n_cols);
  194. out_eT* out_mem = out.memptr();
  195. for(uword i=0; i<B_n_elem; ++i)
  196. {
  197. out_mem[i] = out_eT(-B_mem[i]);
  198. }
  199. const SpProxy<T1> pa(X);
  200. arma_debug_assert_same_size( pa.get_n_rows(), pa.get_n_cols(), out.n_rows, out.n_cols, "subtraction" );
  201. typename SpProxy<T1>::const_iterator_type it = pa.begin();
  202. typename SpProxy<T1>::const_iterator_type it_end = pa.end();
  203. while(it != it_end)
  204. {
  205. out.at(it.row(), it.col()) += out_eT(*it);
  206. ++it;
  207. }
  208. }
  209. template<typename T1, typename T2>
  210. inline
  211. void
  212. spglue_minus_mixed::dense_minus_sparse(Mat< typename promote_type<typename T1::elem_type, typename T2::elem_type >::result>& out, const T1& X, const T2& Y)
  213. {
  214. arma_extra_debug_sigprint();
  215. typedef typename T1::elem_type eT1;
  216. typedef typename T2::elem_type eT2;
  217. typedef typename promote_type<eT1,eT2>::result out_eT;
  218. promote_type<eT1,eT2>::check();
  219. if(is_same_type<eT1,out_eT>::no)
  220. {
  221. out = conv_to< Mat<out_eT> >::from(X);
  222. }
  223. else
  224. {
  225. const quasi_unwrap<T1> UA(X);
  226. const Mat<eT1>& A = UA.M;
  227. out = reinterpret_cast< const Mat<out_eT>& >(A);
  228. }
  229. const SpProxy<T2> pb(Y);
  230. arma_debug_assert_same_size( out.n_rows, out.n_cols, pb.get_n_rows(), pb.get_n_cols(), "subtraction" );
  231. typename SpProxy<T2>::const_iterator_type it = pb.begin();
  232. typename SpProxy<T2>::const_iterator_type it_end = pb.end();
  233. while(it != it_end)
  234. {
  235. out.at(it.row(), it.col()) -= out_eT(*it);
  236. ++it;
  237. }
  238. }
  239. //! @}