spglue_kron_meat.hpp 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup spglue_kron
  16. //! @{
  17. template<typename T1, typename T2>
  18. inline
  19. void
  20. spglue_kron::apply(SpMat<typename T1::elem_type>& out, const SpGlue<T1,T2,spglue_kron>& X)
  21. {
  22. arma_extra_debug_sigprint();
  23. typedef typename T1::elem_type eT;
  24. const unwrap_spmat<T1> UA(X.A);
  25. const unwrap_spmat<T2> UB(X.B);
  26. if(UA.is_alias(out) || UB.is_alias(out))
  27. {
  28. SpMat<eT> tmp;
  29. spglue_kron::apply_noalias(tmp, UA.M, UB.M);
  30. out.steal_mem(tmp);
  31. }
  32. else
  33. {
  34. spglue_kron::apply_noalias(out, UA.M, UB.M);
  35. }
  36. }
  37. template<typename eT>
  38. inline
  39. void
  40. spglue_kron::apply_noalias(SpMat<eT>& out, const SpMat<eT>& A, const SpMat<eT>& B)
  41. {
  42. arma_extra_debug_sigprint();
  43. const uword A_n_rows = A.n_rows;
  44. const uword A_n_cols = A.n_cols;
  45. const uword B_n_rows = B.n_rows;
  46. const uword B_n_cols = B.n_cols;
  47. const uword out_n_nonzero = A.n_nonzero * B.n_nonzero;
  48. out.reserve(A_n_rows * B_n_rows, A_n_cols * B_n_cols, out_n_nonzero);
  49. if(out_n_nonzero == 0) { return; }
  50. access::rw(out.col_ptrs[0]) = 0;
  51. uword count = 0;
  52. for(uword A_col=0; A_col < A_n_cols; ++A_col)
  53. for(uword B_col=0; B_col < B_n_cols; ++B_col)
  54. {
  55. for(uword A_i = A.col_ptrs[A_col]; A_i < A.col_ptrs[A_col+1]; ++A_i)
  56. {
  57. const uword out_row = A.row_indices[A_i] * B_n_rows;
  58. const eT A_val = A.values[A_i];
  59. for(uword B_i = B.col_ptrs[B_col]; B_i < B.col_ptrs[B_col+1]; ++B_i)
  60. {
  61. access::rw(out.values[count]) = A_val * B.values[B_i];
  62. access::rw(out.row_indices[count]) = out_row + B.row_indices[B_i];
  63. count++;
  64. }
  65. }
  66. access::rw(out.col_ptrs[A_col * B_n_cols + B_col + 1]) = count;
  67. }
  68. }
  69. // template<typename T1, typename T2>
  70. // inline
  71. // void
  72. // spglue_kron::apply(SpMat<typename T1::elem_type>& out, const SpGlue<T1,T2,spglue_kron>& X)
  73. // {
  74. // arma_extra_debug_sigprint();
  75. //
  76. // typedef typename T1::elem_type eT;
  77. //
  78. // const unwrap_spmat<T1> UA(X.A);
  79. // const unwrap_spmat<T2> UB(X.B);
  80. //
  81. // const SpMat<eT>& A = UA.M;
  82. // const SpMat<eT>& B = UB.M;
  83. //
  84. // umat locs(2, A.n_nonzero * B.n_nonzero);
  85. // Col<eT> vals( A.n_nonzero * B.n_nonzero);
  86. //
  87. // uword* locs_mem = locs.memptr();
  88. // eT* vals_mem = vals.memptr();
  89. //
  90. // typename SpMat<eT>::const_iterator A_it = A.begin();
  91. // typename SpMat<eT>::const_iterator A_it_end = A.end();
  92. //
  93. // typename SpMat<eT>::const_iterator B_it_start = B.begin();
  94. // typename SpMat<eT>::const_iterator B_it_end = B.end();
  95. //
  96. // const uword B_n_rows = B.n_rows;
  97. // const uword B_n_cols = B.n_cols;
  98. //
  99. // uword i = 0;
  100. //
  101. // while(A_it != A_it_end)
  102. // {
  103. // typename SpMat<eT>::const_iterator B_it = B_it_start;
  104. //
  105. // const uword loc_row = A_it.row() * B_n_rows;
  106. // const uword loc_col = A_it.col() * B_n_cols;
  107. //
  108. // const eT A_val = (*A_it);
  109. //
  110. // while(B_it != B_it_end)
  111. // {
  112. // (*locs_mem) = loc_row + B_it.row(); locs_mem++;
  113. // (*locs_mem) = loc_col + B_it.col(); locs_mem++;
  114. //
  115. // vals_mem[i] = A_val * (*B_it);
  116. //
  117. // ++i;
  118. // ++B_it;
  119. // }
  120. //
  121. // ++A_it;
  122. // }
  123. //
  124. // out = SpMat<eT>(locs, vals, A.n_rows*B.n_rows, A.n_cols*B.n_cols);
  125. // }
  126. //! @}