// spglue_relational_meat.hpp
  // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  // Copyright 2008-2016 National ICT Australia (NICTA)
  //
  // Licensed under the Apache License, Version 2.0 (the "License");
  // you may not use this file except in compliance with the License.
  // You may obtain a copy of the License at
  // http://www.apache.org/licenses/LICENSE-2.0
  //
  // Unless required by applicable law or agreed to in writing, software
  // distributed under the License is distributed on an "AS IS" BASIS,
  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  // See the License for the specific language governing permissions and
  // limitations under the License.
  // ------------------------------------------------------------------------
  //! \addtogroup spglue_relational
  //! @{
  17. template<typename T1, typename T2>
  18. inline
  19. void
  20. spglue_rel_lt::apply(SpMat<uword>& out, const mtSpGlue<uword, T1, T2, spglue_rel_lt>& X)
  21. {
  22. arma_extra_debug_sigprint();
  23. const SpProxy<T1> PA(X.A);
  24. const SpProxy<T2> PB(X.B);
  25. const bool is_alias = PA.is_alias(out) || PB.is_alias(out);
  26. if(is_alias == false)
  27. {
  28. spglue_rel_lt::apply_noalias(out, PA, PB);
  29. }
  30. else
  31. {
  32. SpMat<uword> tmp;
  33. spglue_rel_lt::apply_noalias(tmp, PA, PB);
  34. out.steal_mem(tmp);
  35. }
  36. }
  37. template<typename T1, typename T2>
  38. inline
  39. void
  40. spglue_rel_lt::apply_noalias(SpMat<uword>& out, const SpProxy<T1>& PA, const SpProxy<T2>& PB)
  41. {
  42. arma_extra_debug_sigprint();
  43. typedef typename T1::elem_type eT;
  44. arma_debug_assert_same_size(PA.get_n_rows(), PA.get_n_cols(), PB.get_n_rows(), PB.get_n_cols(), "operator<");
  45. const uword max_n_nonzero = spglue_elem_helper::max_n_nonzero_plus(PA, PB);
  46. // Resize memory to upper bound
  47. out.reserve(PA.get_n_rows(), PA.get_n_cols(), max_n_nonzero);
  48. // Now iterate across both matrices.
  49. typename SpProxy<T1>::const_iterator_type x_it = PA.begin();
  50. typename SpProxy<T1>::const_iterator_type x_end = PA.end();
  51. typename SpProxy<T2>::const_iterator_type y_it = PB.begin();
  52. typename SpProxy<T2>::const_iterator_type y_end = PB.end();
  53. uword count = 0;
  54. while( (x_it != x_end) || (y_it != y_end) )
  55. {
  56. uword out_val;
  57. const uword x_it_col = x_it.col();
  58. const uword x_it_row = x_it.row();
  59. const uword y_it_col = y_it.col();
  60. const uword y_it_row = y_it.row();
  61. bool use_y_loc = false;
  62. if(x_it == y_it)
  63. {
  64. out_val = ((*x_it) < (*y_it)) ? uword(1) : uword(0);
  65. ++x_it;
  66. ++y_it;
  67. }
  68. else
  69. {
  70. if((x_it_col < y_it_col) || ((x_it_col == y_it_col) && (x_it_row < y_it_row))) // if y is closer to the end
  71. {
  72. out_val = ((*x_it) < eT(0)) ? uword(1) : uword(0);
  73. ++x_it;
  74. }
  75. else
  76. {
  77. out_val = (eT(0) < (*y_it)) ? uword(1) : uword(0);
  78. ++y_it;
  79. use_y_loc = true;
  80. }
  81. }
  82. if(out_val != uword(0))
  83. {
  84. access::rw(out.values[count]) = out_val;
  85. const uword out_row = (use_y_loc == false) ? x_it_row : y_it_row;
  86. const uword out_col = (use_y_loc == false) ? x_it_col : y_it_col;
  87. access::rw(out.row_indices[count]) = out_row;
  88. access::rw(out.col_ptrs[out_col + 1])++;
  89. ++count;
  90. }
  91. }
  92. const uword out_n_cols = out.n_cols;
  93. uword* col_ptrs = access::rwp(out.col_ptrs);
  94. // Fix column pointers to be cumulative.
  95. for(uword c = 1; c <= out_n_cols; ++c)
  96. {
  97. col_ptrs[c] += col_ptrs[c - 1];
  98. }
  99. if(count < max_n_nonzero)
  100. {
  101. if(count <= (max_n_nonzero/2))
  102. {
  103. out.mem_resize(count);
  104. }
  105. else
  106. {
  107. // quick resize without reallocating memory and copying data
  108. access::rw( out.n_nonzero) = count;
  109. access::rw( out.values[count]) = eT(0);
  110. access::rw(out.row_indices[count]) = uword(0);
  111. }
  112. }
  113. }
  114. template<typename T1, typename T2>
  115. inline
  116. void
  117. spglue_rel_gt::apply(SpMat<uword>& out, const mtSpGlue<uword, T1, T2, spglue_rel_gt>& X)
  118. {
  119. arma_extra_debug_sigprint();
  120. const SpProxy<T1> PA(X.A);
  121. const SpProxy<T2> PB(X.B);
  122. const bool is_alias = PA.is_alias(out) || PB.is_alias(out);
  123. if(is_alias == false)
  124. {
  125. spglue_rel_gt::apply_noalias(out, PA, PB);
  126. }
  127. else
  128. {
  129. SpMat<uword> tmp;
  130. spglue_rel_gt::apply_noalias(tmp, PA, PB);
  131. out.steal_mem(tmp);
  132. }
  133. }
  134. template<typename T1, typename T2>
  135. inline
  136. void
  137. spglue_rel_gt::apply_noalias(SpMat<uword>& out, const SpProxy<T1>& PA, const SpProxy<T2>& PB)
  138. {
  139. arma_extra_debug_sigprint();
  140. typedef typename T1::elem_type eT;
  141. arma_debug_assert_same_size(PA.get_n_rows(), PA.get_n_cols(), PB.get_n_rows(), PB.get_n_cols(), "operator>");
  142. const uword max_n_nonzero = spglue_elem_helper::max_n_nonzero_plus(PA, PB);
  143. // Resize memory to upper bound
  144. out.reserve(PA.get_n_rows(), PA.get_n_cols(), max_n_nonzero);
  145. // Now iterate across both matrices.
  146. typename SpProxy<T1>::const_iterator_type x_it = PA.begin();
  147. typename SpProxy<T1>::const_iterator_type x_end = PA.end();
  148. typename SpProxy<T2>::const_iterator_type y_it = PB.begin();
  149. typename SpProxy<T2>::const_iterator_type y_end = PB.end();
  150. uword count = 0;
  151. while( (x_it != x_end) || (y_it != y_end) )
  152. {
  153. uword out_val;
  154. const uword x_it_col = x_it.col();
  155. const uword x_it_row = x_it.row();
  156. const uword y_it_col = y_it.col();
  157. const uword y_it_row = y_it.row();
  158. bool use_y_loc = false;
  159. if(x_it == y_it)
  160. {
  161. out_val = ((*x_it) > (*y_it)) ? uword(1) : uword(0);
  162. ++x_it;
  163. ++y_it;
  164. }
  165. else
  166. {
  167. if((x_it_col < y_it_col) || ((x_it_col == y_it_col) && (x_it_row < y_it_row))) // if y is closer to the end
  168. {
  169. out_val = ((*x_it) > eT(0)) ? uword(1) : uword(0);
  170. ++x_it;
  171. }
  172. else
  173. {
  174. out_val = (eT(0) > (*y_it)) ? uword(1) : uword(0);
  175. ++y_it;
  176. use_y_loc = true;
  177. }
  178. }
  179. if(out_val != uword(0))
  180. {
  181. access::rw(out.values[count]) = out_val;
  182. const uword out_row = (use_y_loc == false) ? x_it_row : y_it_row;
  183. const uword out_col = (use_y_loc == false) ? x_it_col : y_it_col;
  184. access::rw(out.row_indices[count]) = out_row;
  185. access::rw(out.col_ptrs[out_col + 1])++;
  186. ++count;
  187. }
  188. }
  189. const uword out_n_cols = out.n_cols;
  190. uword* col_ptrs = access::rwp(out.col_ptrs);
  191. // Fix column pointers to be cumulative.
  192. for(uword c = 1; c <= out_n_cols; ++c)
  193. {
  194. col_ptrs[c] += col_ptrs[c - 1];
  195. }
  196. if(count < max_n_nonzero)
  197. {
  198. if(count <= (max_n_nonzero/2))
  199. {
  200. out.mem_resize(count);
  201. }
  202. else
  203. {
  204. // quick resize without reallocating memory and copying data
  205. access::rw( out.n_nonzero) = count;
  206. access::rw( out.values[count]) = eT(0);
  207. access::rw(out.row_indices[count]) = uword(0);
  208. }
  209. }
  210. }
  //! @}