glue_conv_meat.hpp 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup glue_conv
  16. //! @{
  17. // TODO: this implementation of conv() is rudimentary; replace with faster version
  18. template<typename eT>
  19. inline
  20. void
  21. glue_conv::apply(Mat<eT>& out, const Mat<eT>& A, const Mat<eT>& B, const bool A_is_col)
  22. {
  23. arma_extra_debug_sigprint();
  24. const Mat<eT>& h = (A.n_elem <= B.n_elem) ? A : B;
  25. const Mat<eT>& x = (A.n_elem <= B.n_elem) ? B : A;
  26. const uword h_n_elem = h.n_elem;
  27. const uword h_n_elem_m1 = h_n_elem - 1;
  28. const uword x_n_elem = x.n_elem;
  29. const uword out_n_elem = ((h_n_elem + x_n_elem) > 0) ? (h_n_elem + x_n_elem - 1) : uword(0);
  30. if( (h_n_elem == 0) || (x_n_elem == 0) ) { out.zeros(); return; }
  31. Col<eT> hh(h_n_elem); // flipped version of h
  32. const eT* h_mem = h.memptr();
  33. eT* hh_mem = hh.memptr();
  34. for(uword i=0; i < h_n_elem; ++i)
  35. {
  36. hh_mem[h_n_elem_m1-i] = h_mem[i];
  37. }
  38. Col<eT> xx( (x_n_elem + 2*h_n_elem_m1), fill::zeros ); // zero padded version of x
  39. const eT* x_mem = x.memptr();
  40. eT* xx_mem = xx.memptr();
  41. arrayops::copy( &(xx_mem[h_n_elem_m1]), x_mem, x_n_elem );
  42. (A_is_col) ? out.set_size(out_n_elem, 1) : out.set_size(1, out_n_elem);
  43. eT* out_mem = out.memptr();
  44. for(uword i=0; i < out_n_elem; ++i)
  45. {
  46. // out_mem[i] = dot( hh, xx.subvec(i, (i + h_n_elem_m1)) );
  47. out_mem[i] = op_dot::direct_dot( h_n_elem, hh_mem, &(xx_mem[i]) );
  48. }
  49. }
  50. // // alternative implementation of 1d convolution
  51. // template<typename eT>
  52. // inline
  53. // void
  54. // glue_conv::apply(Mat<eT>& out, const Mat<eT>& A, const Mat<eT>& B, const bool A_is_col)
  55. // {
  56. // arma_extra_debug_sigprint();
  57. //
  58. // const Mat<eT>& h = (A.n_elem <= B.n_elem) ? A : B;
  59. // const Mat<eT>& x = (A.n_elem <= B.n_elem) ? B : A;
  60. //
  61. // const uword h_n_elem = h.n_elem;
  62. // const uword h_n_elem_m1 = h_n_elem - 1;
  63. // const uword x_n_elem = x.n_elem;
  64. // const uword out_n_elem = ((h_n_elem + x_n_elem) > 0) ? (h_n_elem + x_n_elem - 1) : uword(0);
  65. //
  66. // if( (h_n_elem == 0) || (x_n_elem == 0) ) { out.zeros(); return; }
  67. //
  68. //
  69. // Col<eT> hh(h_n_elem); // flipped version of h
  70. //
  71. // const eT* h_mem = h.memptr();
  72. // eT* hh_mem = hh.memptr();
  73. //
  74. // for(uword i=0; i < h_n_elem; ++i)
  75. // {
  76. // hh_mem[h_n_elem_m1-i] = h_mem[i];
  77. // }
  78. //
  79. // // construct HH matrix, with the column containing shifted versions of hh;
  80. // // upper limit for number of zeros is about 50%; may not be optimal
  81. // const uword N_copies = (std::min)(uword(10), h_n_elem);
  82. //
  83. // const uword HH_n_rows = h_n_elem + (N_copies-1);
  84. //
  85. // Mat<eT> HH(HH_n_rows, N_copies, fill::zeros);
  86. //
  87. // for(uword i=0; i<N_copies; ++i)
  88. // {
  89. // arrayops::copy(HH.colptr(i) + i, hh.memptr(), h_n_elem);
  90. // }
  91. //
  92. //
  93. //
  94. // Col<eT> xx( (x_n_elem + 2*h_n_elem_m1), fill::zeros ); // zero padded version of x
  95. //
  96. // const eT* x_mem = x.memptr();
  97. // eT* xx_mem = xx.memptr();
  98. //
  99. // arrayops::copy( &(xx_mem[h_n_elem_m1]), x_mem, x_n_elem );
  100. //
  101. //
  102. // (A_is_col) ? out.set_size(out_n_elem, 1) : out.set_size(1, out_n_elem);
  103. //
  104. // eT* out_mem = out.memptr();
  105. //
  106. // uword last_i = 0;
  107. // bool last_i_done = false;
  108. //
  109. // for(uword i=0; i < xx.n_elem; i += N_copies)
  110. // {
  111. // if( ((i + HH_n_rows) <= xx.n_elem) && ((i + N_copies) <= out_n_elem) )
  112. // {
  113. // const Row<eT> xx_sub(xx_mem + i, HH_n_rows, false, true);
  114. //
  115. // Row<eT> out_sub(out_mem + i, N_copies, false, true);
  116. //
  117. // out_sub = xx_sub * HH;
  118. //
  119. // last_i_done = true;
  120. // }
  121. // else
  122. // {
  123. // last_i = i;
  124. // last_i_done = false;
  125. // break;
  126. // }
  127. // }
  128. //
  129. // if(last_i_done == false)
  130. // {
  131. // for(uword i=last_i; i < out_n_elem; ++i)
  132. // {
  133. // // out_mem[i] = dot( hh, xx.subvec(i, (i + h_n_elem_m1)) );
  134. //
  135. // out_mem[i] = op_dot::direct_dot( h_n_elem, hh_mem, &(xx_mem[i]) );
  136. // }
  137. // }
  138. // }
  139. template<typename T1, typename T2>
  140. inline
  141. void
  142. glue_conv::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_conv>& expr)
  143. {
  144. arma_extra_debug_sigprint();
  145. typedef typename T1::elem_type eT;
  146. const quasi_unwrap<T1> UA(expr.A);
  147. const quasi_unwrap<T2> UB(expr.B);
  148. const Mat<eT>& A = UA.M;
  149. const Mat<eT>& B = UB.M;
  150. arma_debug_check
  151. (
  152. ( ((A.is_vec() == false) && (A.is_empty() == false)) || ((B.is_vec() == false) && (B.is_empty() == false)) ),
  153. "conv(): given object is not a vector"
  154. );
  155. const bool A_is_col = ((T1::is_col) || (A.n_cols == 1));
  156. const uword mode = expr.aux_uword;
  157. if(mode == 0) // full convolution
  158. {
  159. glue_conv::apply(out, A, B, A_is_col);
  160. }
  161. else
  162. if(mode == 1) // same size as A
  163. {
  164. Mat<eT> tmp;
  165. glue_conv::apply(tmp, A, B, A_is_col);
  166. if( (tmp.is_empty() == false) && (A.is_empty() == false) && (B.is_empty() == false) )
  167. {
  168. const uword start = uword( std::floor( double(B.n_elem) / double(2) ) );
  169. out = (A_is_col) ? tmp(start, 0, arma::size(A)) : tmp(0, start, arma::size(A));
  170. }
  171. else
  172. {
  173. out.zeros( arma::size(A) );
  174. }
  175. }
  176. }
  177. ///
  178. // TODO: this implementation of conv2() is rudimentary; replace with faster version
  179. template<typename eT>
  180. inline
  181. void
  182. glue_conv2::apply(Mat<eT>& out, const Mat<eT>& A, const Mat<eT>& B)
  183. {
  184. arma_extra_debug_sigprint();
  185. const Mat<eT>& G = (A.n_elem <= B.n_elem) ? A : B; // unflipped filter coefficients
  186. const Mat<eT>& W = (A.n_elem <= B.n_elem) ? B : A; // original 2D image
  187. const uword out_n_rows = ((W.n_rows + G.n_rows) > 0) ? (W.n_rows + G.n_rows - 1) : uword(0);
  188. const uword out_n_cols = ((W.n_cols + G.n_cols) > 0) ? (W.n_cols + G.n_cols - 1) : uword(0);
  189. if(G.is_empty() || W.is_empty()) { out.zeros(); return; }
  190. Mat<eT> H(G.n_rows, G.n_cols); // flipped filter coefficients
  191. const uword H_n_rows = H.n_rows;
  192. const uword H_n_cols = H.n_cols;
  193. const uword H_n_rows_m1 = H_n_rows - 1;
  194. const uword H_n_cols_m1 = H_n_cols - 1;
  195. for(uword col=0; col < H_n_cols; ++col)
  196. {
  197. eT* H_colptr = H.colptr(H_n_cols_m1 - col);
  198. const eT* G_colptr = G.colptr(col);
  199. for(uword row=0; row < H_n_rows; ++row)
  200. {
  201. H_colptr[H_n_rows_m1 - row] = G_colptr[row];
  202. }
  203. }
  204. Mat<eT> X( (W.n_rows + 2*H_n_rows_m1), (W.n_cols + 2*H_n_cols_m1), fill::zeros );
  205. X( H_n_rows_m1, H_n_cols_m1, arma::size(W) ) = W; // zero padded version of 2D image
  206. out.set_size( out_n_rows, out_n_cols );
  207. for(uword col=0; col < out_n_cols; ++col)
  208. {
  209. eT* out_colptr = out.colptr(col);
  210. for(uword row=0; row < out_n_rows; ++row)
  211. {
  212. // out.at(row, col) = accu( H % X(row, col, size(H)) );
  213. eT acc = eT(0);
  214. for(uword H_col = 0; H_col < H_n_cols; ++H_col)
  215. {
  216. const eT* X_colptr = X.colptr(col + H_col);
  217. acc += op_dot::direct_dot( H_n_rows, H.colptr(H_col), &(X_colptr[row]) );
  218. }
  219. out_colptr[row] = acc;
  220. }
  221. }
  222. }
  223. template<typename T1, typename T2>
  224. inline
  225. void
  226. glue_conv2::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_conv2>& expr)
  227. {
  228. arma_extra_debug_sigprint();
  229. typedef typename T1::elem_type eT;
  230. const quasi_unwrap<T1> UA(expr.A);
  231. const quasi_unwrap<T2> UB(expr.B);
  232. const Mat<eT>& A = UA.M;
  233. const Mat<eT>& B = UB.M;
  234. const uword mode = expr.aux_uword;
  235. if(mode == 0) // full convolution
  236. {
  237. glue_conv2::apply(out, A, B);
  238. }
  239. else
  240. if(mode == 1) // same size as A
  241. {
  242. Mat<eT> tmp;
  243. glue_conv2::apply(tmp, A, B);
  244. if( (tmp.is_empty() == false) && (A.is_empty() == false) && (B.is_empty() == false) )
  245. {
  246. const uword start_row = uword( std::floor( double(B.n_rows) / double(2) ) );
  247. const uword start_col = uword( std::floor( double(B.n_cols) / double(2) ) );
  248. out = tmp(start_row, start_col, arma::size(A));
  249. }
  250. else
  251. {
  252. out.zeros( arma::size(A) );
  253. }
  254. }
  255. }
  256. //! @}