Gen_meat.hpp 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup Gen
  16. //! @{
  17. template<typename T1, typename gen_type>
  18. arma_inline
  19. Gen<T1, gen_type>::Gen(const uword in_n_rows, const uword in_n_cols)
  20. : n_rows(in_n_rows)
  21. , n_cols(in_n_cols)
  22. {
  23. arma_extra_debug_sigprint();
  24. }
  25. template<typename T1, typename gen_type>
  26. arma_inline
  27. Gen<T1, gen_type>::~Gen()
  28. {
  29. arma_extra_debug_sigprint();
  30. }
  31. template<typename T1, typename gen_type>
  32. arma_inline
  33. typename T1::elem_type
  34. Gen<T1, gen_type>::operator[](const uword ii) const
  35. {
  36. typedef typename T1::elem_type eT;
  37. if(is_same_type<gen_type, gen_eye>::yes)
  38. {
  39. return ((ii % n_rows) == (ii / n_rows)) ? eT(1) : eT(0);
  40. }
  41. else
  42. {
  43. return (*this).generate();
  44. }
  45. }
  46. template<typename T1, typename gen_type>
  47. arma_inline
  48. typename T1::elem_type
  49. Gen<T1, gen_type>::at(const uword row, const uword col) const
  50. {
  51. typedef typename T1::elem_type eT;
  52. if(is_same_type<gen_type, gen_eye>::yes)
  53. {
  54. return (row == col) ? eT(1) : eT(0);
  55. }
  56. else
  57. {
  58. return (*this).generate();
  59. }
  60. }
  61. template<typename T1, typename gen_type>
  62. arma_inline
  63. typename T1::elem_type
  64. Gen<T1, gen_type>::at_alt(const uword ii) const
  65. {
  66. return operator[](ii);
  67. }
  68. template<typename T1, typename gen_type>
  69. inline
  70. void
  71. Gen<T1, gen_type>::apply(Mat<typename T1::elem_type>& out) const
  72. {
  73. arma_extra_debug_sigprint();
  74. // NOTE: we're assuming that the matrix has already been set to the correct size;
  75. // this is done by either the Mat contructor or operator=()
  76. if(is_same_type<gen_type, gen_eye >::yes) { out.eye(); }
  77. else if(is_same_type<gen_type, gen_ones >::yes) { out.ones(); }
  78. else if(is_same_type<gen_type, gen_zeros>::yes) { out.zeros(); }
  79. else if(is_same_type<gen_type, gen_randu>::yes) { out.randu(); }
  80. else if(is_same_type<gen_type, gen_randn>::yes) { out.randn(); }
  81. }
  82. template<typename T1, typename gen_type>
  83. inline
  84. void
  85. Gen<T1, gen_type>::apply_inplace_plus(Mat<typename T1::elem_type>& out) const
  86. {
  87. arma_extra_debug_sigprint();
  88. arma_debug_assert_same_size(out.n_rows, out.n_cols, n_rows, n_cols, "addition");
  89. typedef typename T1::elem_type eT;
  90. if(is_same_type<gen_type, gen_eye>::yes)
  91. {
  92. const uword N = (std::min)(n_rows, n_cols);
  93. for(uword iq=0; iq < N; ++iq)
  94. {
  95. out.at(iq,iq) += eT(1);
  96. }
  97. }
  98. else
  99. {
  100. eT* out_mem = out.memptr();
  101. const uword n_elem = out.n_elem;
  102. uword iq,jq;
  103. for(iq=0, jq=1; jq < n_elem; iq+=2, jq+=2)
  104. {
  105. const eT tmp_i = (*this).generate();
  106. const eT tmp_j = (*this).generate();
  107. out_mem[iq] += tmp_i;
  108. out_mem[jq] += tmp_j;
  109. }
  110. if(iq < n_elem)
  111. {
  112. out_mem[iq] += (*this).generate();
  113. }
  114. }
  115. }
  116. template<typename T1, typename gen_type>
  117. inline
  118. void
  119. Gen<T1, gen_type>::apply_inplace_minus(Mat<typename T1::elem_type>& out) const
  120. {
  121. arma_extra_debug_sigprint();
  122. arma_debug_assert_same_size(out.n_rows, out.n_cols, n_rows, n_cols, "subtraction");
  123. typedef typename T1::elem_type eT;
  124. if(is_same_type<gen_type, gen_eye>::yes)
  125. {
  126. const uword N = (std::min)(n_rows, n_cols);
  127. for(uword iq=0; iq < N; ++iq)
  128. {
  129. out.at(iq,iq) -= eT(1);
  130. }
  131. }
  132. else
  133. {
  134. eT* out_mem = out.memptr();
  135. const uword n_elem = out.n_elem;
  136. uword iq,jq;
  137. for(iq=0, jq=1; jq < n_elem; iq+=2, jq+=2)
  138. {
  139. const eT tmp_i = (*this).generate();
  140. const eT tmp_j = (*this).generate();
  141. out_mem[iq] -= tmp_i;
  142. out_mem[jq] -= tmp_j;
  143. }
  144. if(iq < n_elem)
  145. {
  146. out_mem[iq] -= (*this).generate();
  147. }
  148. }
  149. }
  150. template<typename T1, typename gen_type>
  151. inline
  152. void
  153. Gen<T1, gen_type>::apply_inplace_schur(Mat<typename T1::elem_type>& out) const
  154. {
  155. arma_extra_debug_sigprint();
  156. arma_debug_assert_same_size(out.n_rows, out.n_cols, n_rows, n_cols, "element-wise multiplication");
  157. typedef typename T1::elem_type eT;
  158. if(is_same_type<gen_type, gen_eye>::yes)
  159. {
  160. const uword N = (std::min)(n_rows, n_cols);
  161. for(uword iq=0; iq < N; ++iq)
  162. {
  163. for(uword row=0; row < iq; ++row) { out.at(row,iq) = eT(0); }
  164. for(uword row=iq+1; row < n_rows; ++row) { out.at(row,iq) = eT(0); }
  165. }
  166. }
  167. else
  168. {
  169. eT* out_mem = out.memptr();
  170. const uword n_elem = out.n_elem;
  171. uword iq,jq;
  172. for(iq=0, jq=1; jq < n_elem; iq+=2, jq+=2)
  173. {
  174. const eT tmp_i = (*this).generate();
  175. const eT tmp_j = (*this).generate();
  176. out_mem[iq] *= tmp_i;
  177. out_mem[jq] *= tmp_j;
  178. }
  179. if(iq < n_elem)
  180. {
  181. out_mem[iq] *= (*this).generate();
  182. }
  183. }
  184. }
  185. template<typename T1, typename gen_type>
  186. inline
  187. void
  188. Gen<T1, gen_type>::apply_inplace_div(Mat<typename T1::elem_type>& out) const
  189. {
  190. arma_extra_debug_sigprint();
  191. arma_debug_assert_same_size(out.n_rows, out.n_cols, n_rows, n_cols, "element-wise division");
  192. typedef typename T1::elem_type eT;
  193. if(is_same_type<gen_type, gen_eye>::yes)
  194. {
  195. const uword N = (std::min)(n_rows, n_cols);
  196. for(uword iq=0; iq < N; ++iq)
  197. {
  198. const eT zero = eT(0);
  199. for(uword row=0; row < iq; ++row) { out.at(row,iq) /= zero; }
  200. for(uword row=iq+1; row < n_rows; ++row) { out.at(row,iq) /= zero; }
  201. }
  202. }
  203. else
  204. {
  205. eT* out_mem = out.memptr();
  206. const uword n_elem = out.n_elem;
  207. uword iq,jq;
  208. for(iq=0, jq=1; jq < n_elem; iq+=2, jq+=2)
  209. {
  210. const eT tmp_i = (*this).generate();
  211. const eT tmp_j = (*this).generate();
  212. out_mem[iq] /= tmp_i;
  213. out_mem[jq] /= tmp_j;
  214. }
  215. if(iq < n_elem)
  216. {
  217. out_mem[iq] /= (*this).generate();
  218. }
  219. }
  220. }
  221. template<typename T1, typename gen_type>
  222. inline
  223. void
  224. Gen<T1, gen_type>::apply(subview<typename T1::elem_type>& out) const
  225. {
  226. arma_extra_debug_sigprint();
  227. // NOTE: we're assuming that the submatrix has the same dimensions as the Gen object
  228. // this is checked by subview::operator=()
  229. if(is_same_type<gen_type, gen_eye >::yes) { out.eye(); }
  230. else if(is_same_type<gen_type, gen_ones >::yes) { out.ones(); }
  231. else if(is_same_type<gen_type, gen_zeros>::yes) { out.zeros(); }
  232. else if(is_same_type<gen_type, gen_randu>::yes) { out.randu(); }
  233. else if(is_same_type<gen_type, gen_randn>::yes) { out.randn(); }
  234. }
  235. //! @}