op_dotext_meat.hpp 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup op_dotext
  16. //! @{
  17. template<typename eT>
  18. inline
  19. eT
  20. op_dotext::direct_rowvec_mat_colvec
  21. (
  22. const eT* A_mem,
  23. const Mat<eT>& B,
  24. const eT* C_mem
  25. )
  26. {
  27. arma_extra_debug_sigprint();
  28. const uword cost_AB = B.n_cols;
  29. const uword cost_BC = B.n_rows;
  30. if(cost_AB <= cost_BC)
  31. {
  32. podarray<eT> tmp(B.n_cols);
  33. for(uword col=0; col<B.n_cols; ++col)
  34. {
  35. const eT* B_coldata = B.colptr(col);
  36. eT val = eT(0);
  37. for(uword i=0; i<B.n_rows; ++i)
  38. {
  39. val += A_mem[i] * B_coldata[i];
  40. }
  41. tmp[col] = val;
  42. }
  43. return op_dot::direct_dot(B.n_cols, tmp.mem, C_mem);
  44. }
  45. else
  46. {
  47. podarray<eT> tmp(B.n_rows);
  48. for(uword row=0; row<B.n_rows; ++row)
  49. {
  50. eT val = eT(0);
  51. for(uword col=0; col<B.n_cols; ++col)
  52. {
  53. val += B.at(row,col) * C_mem[col];
  54. }
  55. tmp[row] = val;
  56. }
  57. return op_dot::direct_dot(B.n_rows, A_mem, tmp.mem);
  58. }
  59. }
  60. template<typename eT>
  61. inline
  62. eT
  63. op_dotext::direct_rowvec_transmat_colvec
  64. (
  65. const eT* A_mem,
  66. const Mat<eT>& B,
  67. const eT* C_mem
  68. )
  69. {
  70. arma_extra_debug_sigprint();
  71. const uword cost_AB = B.n_rows;
  72. const uword cost_BC = B.n_cols;
  73. if(cost_AB <= cost_BC)
  74. {
  75. podarray<eT> tmp(B.n_rows);
  76. for(uword row=0; row<B.n_rows; ++row)
  77. {
  78. eT val = eT(0);
  79. for(uword i=0; i<B.n_cols; ++i)
  80. {
  81. val += A_mem[i] * B.at(row,i);
  82. }
  83. tmp[row] = val;
  84. }
  85. return op_dot::direct_dot(B.n_rows, tmp.mem, C_mem);
  86. }
  87. else
  88. {
  89. podarray<eT> tmp(B.n_cols);
  90. for(uword col=0; col<B.n_cols; ++col)
  91. {
  92. const eT* B_coldata = B.colptr(col);
  93. eT val = eT(0);
  94. for(uword i=0; i<B.n_rows; ++i)
  95. {
  96. val += B_coldata[i] * C_mem[i];
  97. }
  98. tmp[col] = val;
  99. }
  100. return op_dot::direct_dot(B.n_cols, A_mem, tmp.mem);
  101. }
  102. }
  103. template<typename eT>
  104. inline
  105. eT
  106. op_dotext::direct_rowvec_diagmat_colvec
  107. (
  108. const eT* A_mem,
  109. const Mat<eT>& B,
  110. const eT* C_mem
  111. )
  112. {
  113. arma_extra_debug_sigprint();
  114. eT val = eT(0);
  115. for(uword i=0; i<B.n_rows; ++i)
  116. {
  117. val += A_mem[i] * B.at(i,i) * C_mem[i];
  118. }
  119. return val;
  120. }
  121. template<typename eT>
  122. inline
  123. eT
  124. op_dotext::direct_rowvec_invdiagmat_colvec
  125. (
  126. const eT* A_mem,
  127. const Mat<eT>& B,
  128. const eT* C_mem
  129. )
  130. {
  131. arma_extra_debug_sigprint();
  132. eT val = eT(0);
  133. for(uword i=0; i<B.n_rows; ++i)
  134. {
  135. val += (A_mem[i] * C_mem[i]) / B.at(i,i);
  136. }
  137. return val;
  138. }
  139. template<typename eT>
  140. inline
  141. eT
  142. op_dotext::direct_rowvec_invdiagvec_colvec
  143. (
  144. const eT* A_mem,
  145. const Mat<eT>& B,
  146. const eT* C_mem
  147. )
  148. {
  149. arma_extra_debug_sigprint();
  150. const eT* B_mem = B.mem;
  151. eT val = eT(0);
  152. for(uword i=0; i<B.n_elem; ++i)
  153. {
  154. val += (A_mem[i] * C_mem[i]) / B_mem[i];
  155. }
  156. return val;
  157. }
  158. //! @}