op_shift_meat.hpp 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup op_shift
  16. //! @{
  17. template<typename T1>
  18. inline
  19. void
  20. op_shift_vec::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_shift_vec>& in)
  21. {
  22. arma_extra_debug_sigprint();
  23. const unwrap<T1> U(in.m);
  24. const uword len = in.aux_uword_a;
  25. const uword neg = in.aux_uword_b;
  26. const uword dim = (T1::is_xvec) ? uword(U.M.is_rowvec() ? 1 : 0) : uword((T1::is_row) ? 1 : 0);
  27. op_shift::apply_direct(out, U.M, len, neg, dim);
  28. }
  29. template<typename T1>
  30. inline
  31. void
  32. op_shift::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_shift>& in)
  33. {
  34. arma_extra_debug_sigprint();
  35. const unwrap<T1> U(in.m);
  36. const uword len = in.aux_uword_a;
  37. const uword neg = in.aux_uword_b;
  38. const uword dim = in.aux_uword_c;
  39. arma_debug_check( (dim > 1), "shift(): parameter 'dim' must be 0 or 1" );
  40. op_shift::apply_direct(out, U.M, len, neg, dim);
  41. }
  42. template<typename eT>
  43. inline
  44. void
  45. op_shift::apply_direct(Mat<eT>& out, const Mat<eT>& X, const uword len, const uword neg, const uword dim)
  46. {
  47. arma_extra_debug_sigprint();
  48. arma_debug_check( ((dim == 0) && (len >= X.n_rows)), "shift(): shift amount out of bounds" );
  49. arma_debug_check( ((dim == 1) && (len >= X.n_cols)), "shift(): shift amount out of bounds" );
  50. if(&out == &X)
  51. {
  52. op_shift::apply_alias(out, len, neg, dim);
  53. }
  54. else
  55. {
  56. op_shift::apply_noalias(out, X, len, neg, dim);
  57. }
  58. }
  59. template<typename eT>
  60. inline
  61. void
  62. op_shift::apply_noalias(Mat<eT>& out, const Mat<eT>& X, const uword len, const uword neg, const uword dim)
  63. {
  64. arma_extra_debug_sigprint();
  65. out.copy_size(X);
  66. const uword X_n_rows = X.n_rows;
  67. const uword X_n_cols = X.n_cols;
  68. if(dim == 0)
  69. {
  70. if(neg == 0)
  71. {
  72. for(uword col=0; col < X_n_cols; ++col)
  73. {
  74. eT* out_ptr = out.colptr(col);
  75. const eT* X_ptr = X.colptr(col);
  76. for(uword out_row=len, row=0; row < (X_n_rows - len); ++row, ++out_row)
  77. {
  78. out_ptr[out_row] = X_ptr[row];
  79. }
  80. for(uword out_row=0, row=(X_n_rows - len); row < X_n_rows; ++row, ++out_row)
  81. {
  82. out_ptr[out_row] = X_ptr[row];
  83. }
  84. }
  85. }
  86. else
  87. if(neg == 1)
  88. {
  89. for(uword col=0; col < X_n_cols; ++col)
  90. {
  91. eT* out_ptr = out.colptr(col);
  92. const eT* X_ptr = X.colptr(col);
  93. for(uword out_row=0, row=len; row < X_n_rows; ++row, ++out_row)
  94. {
  95. out_ptr[out_row] = X_ptr[row];
  96. }
  97. for(uword out_row=(X_n_rows-len), row=0; row < len; ++row, ++out_row)
  98. {
  99. out_ptr[out_row] = X_ptr[row];
  100. }
  101. }
  102. }
  103. }
  104. else
  105. if(dim == 1)
  106. {
  107. if(neg == 0)
  108. {
  109. if(X_n_rows == 1)
  110. {
  111. eT* out_ptr = out.memptr();
  112. const eT* X_ptr = X.memptr();
  113. for(uword out_col=len, col=0; col < (X_n_cols - len); ++col, ++out_col)
  114. {
  115. out_ptr[out_col] = X_ptr[col];
  116. }
  117. for(uword out_col=0, col=(X_n_cols - len); col < X_n_cols; ++col, ++out_col)
  118. {
  119. out_ptr[out_col] = X_ptr[col];
  120. }
  121. }
  122. else
  123. {
  124. for(uword out_col=len, col=0; col < (X_n_cols - len); ++col, ++out_col)
  125. {
  126. arrayops::copy( out.colptr(out_col), X.colptr(col), X_n_rows );
  127. }
  128. for(uword out_col=0, col=(X_n_cols - len); col < X_n_cols; ++col, ++out_col)
  129. {
  130. arrayops::copy( out.colptr(out_col), X.colptr(col), X_n_rows );
  131. }
  132. }
  133. }
  134. else
  135. if(neg == 1)
  136. {
  137. if(X_n_rows == 1)
  138. {
  139. eT* out_ptr = out.memptr();
  140. const eT* X_ptr = X.memptr();
  141. for(uword out_col=0, col=len; col < X_n_cols; ++col, ++out_col)
  142. {
  143. out_ptr[out_col] = X_ptr[col];
  144. }
  145. for(uword out_col=(X_n_cols-len), col=0; col < len; ++col, ++out_col)
  146. {
  147. out_ptr[out_col] = X_ptr[col];
  148. }
  149. }
  150. else
  151. {
  152. for(uword out_col=0, col=len; col < X_n_cols; ++col, ++out_col)
  153. {
  154. arrayops::copy( out.colptr(out_col), X.colptr(col), X_n_rows );
  155. }
  156. for(uword out_col=(X_n_cols-len), col=0; col < len; ++col, ++out_col)
  157. {
  158. arrayops::copy( out.colptr(out_col), X.colptr(col), X_n_rows );
  159. }
  160. }
  161. }
  162. }
  163. }
  164. template<typename eT>
  165. inline
  166. void
  167. op_shift::apply_alias(Mat<eT>& X, const uword len, const uword neg, const uword dim)
  168. {
  169. arma_extra_debug_sigprint();
  170. // TODO: replace with better implementation
  171. Mat<eT> tmp;
  172. op_shift::apply_noalias(tmp, X, len, neg, dim);
  173. X.steal_mem(tmp);
  174. }
  175. //! @}