op_strans_meat.hpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup op_strans
  16. //! @{
  17. //! for tiny square matrices (size <= 4x4)
  18. template<typename eT, typename TA>
  19. arma_cold
  20. inline
  21. void
  22. op_strans::apply_mat_noalias_tinysq(Mat<eT>& out, const TA& A)
  23. {
  24. const eT* Am = A.memptr();
  25. eT* outm = out.memptr();
  26. switch(A.n_rows)
  27. {
  28. case 1:
  29. {
  30. outm[0] = Am[0];
  31. }
  32. break;
  33. case 2:
  34. {
  35. outm[pos<false,0,0>::n2] = Am[pos<true,0,0>::n2];
  36. outm[pos<false,1,0>::n2] = Am[pos<true,1,0>::n2];
  37. outm[pos<false,0,1>::n2] = Am[pos<true,0,1>::n2];
  38. outm[pos<false,1,1>::n2] = Am[pos<true,1,1>::n2];
  39. }
  40. break;
  41. case 3:
  42. {
  43. outm[pos<false,0,0>::n3] = Am[pos<true,0,0>::n3];
  44. outm[pos<false,1,0>::n3] = Am[pos<true,1,0>::n3];
  45. outm[pos<false,2,0>::n3] = Am[pos<true,2,0>::n3];
  46. outm[pos<false,0,1>::n3] = Am[pos<true,0,1>::n3];
  47. outm[pos<false,1,1>::n3] = Am[pos<true,1,1>::n3];
  48. outm[pos<false,2,1>::n3] = Am[pos<true,2,1>::n3];
  49. outm[pos<false,0,2>::n3] = Am[pos<true,0,2>::n3];
  50. outm[pos<false,1,2>::n3] = Am[pos<true,1,2>::n3];
  51. outm[pos<false,2,2>::n3] = Am[pos<true,2,2>::n3];
  52. }
  53. break;
  54. case 4:
  55. {
  56. outm[pos<false,0,0>::n4] = Am[pos<true,0,0>::n4];
  57. outm[pos<false,1,0>::n4] = Am[pos<true,1,0>::n4];
  58. outm[pos<false,2,0>::n4] = Am[pos<true,2,0>::n4];
  59. outm[pos<false,3,0>::n4] = Am[pos<true,3,0>::n4];
  60. outm[pos<false,0,1>::n4] = Am[pos<true,0,1>::n4];
  61. outm[pos<false,1,1>::n4] = Am[pos<true,1,1>::n4];
  62. outm[pos<false,2,1>::n4] = Am[pos<true,2,1>::n4];
  63. outm[pos<false,3,1>::n4] = Am[pos<true,3,1>::n4];
  64. outm[pos<false,0,2>::n4] = Am[pos<true,0,2>::n4];
  65. outm[pos<false,1,2>::n4] = Am[pos<true,1,2>::n4];
  66. outm[pos<false,2,2>::n4] = Am[pos<true,2,2>::n4];
  67. outm[pos<false,3,2>::n4] = Am[pos<true,3,2>::n4];
  68. outm[pos<false,0,3>::n4] = Am[pos<true,0,3>::n4];
  69. outm[pos<false,1,3>::n4] = Am[pos<true,1,3>::n4];
  70. outm[pos<false,2,3>::n4] = Am[pos<true,2,3>::n4];
  71. outm[pos<false,3,3>::n4] = Am[pos<true,3,3>::n4];
  72. }
  73. break;
  74. default:
  75. ;
  76. }
  77. }
  78. template<typename eT>
  79. arma_hot
  80. inline
  81. void
  82. op_strans::block_worker(eT* Y, const eT* X, const uword X_n_rows, const uword Y_n_rows, const uword n_rows, const uword n_cols)
  83. {
  84. for(uword row = 0; row < n_rows; ++row)
  85. {
  86. const uword Y_offset = row * Y_n_rows;
  87. for(uword col = 0; col < n_cols; ++col)
  88. {
  89. const uword X_offset = col * X_n_rows;
  90. Y[col + Y_offset] = X[row + X_offset];
  91. }
  92. }
  93. }
  94. template<typename eT>
  95. arma_hot
  96. inline
  97. void
  98. op_strans::apply_mat_noalias_large(Mat<eT>& out, const Mat<eT>& A)
  99. {
  100. arma_extra_debug_sigprint();
  101. const uword n_rows = A.n_rows;
  102. const uword n_cols = A.n_cols;
  103. const uword block_size = 64;
  104. const uword n_rows_base = block_size * (n_rows / block_size);
  105. const uword n_cols_base = block_size * (n_cols / block_size);
  106. const uword n_rows_extra = n_rows - n_rows_base;
  107. const uword n_cols_extra = n_cols - n_cols_base;
  108. const eT* X = A.memptr();
  109. eT* Y = out.memptr();
  110. for(uword row = 0; row < n_rows_base; row += block_size)
  111. {
  112. const uword Y_offset = row * n_cols;
  113. for(uword col = 0; col < n_cols_base; col += block_size)
  114. {
  115. const uword X_offset = col * n_rows;
  116. op_strans::block_worker(&Y[col + Y_offset], &X[row + X_offset], n_rows, n_cols, block_size, block_size);
  117. }
  118. const uword X_offset = n_cols_base * n_rows;
  119. op_strans::block_worker(&Y[n_cols_base + Y_offset], &X[row + X_offset], n_rows, n_cols, block_size, n_cols_extra);
  120. }
  121. if(n_rows_extra == 0) { return; }
  122. const uword Y_offset = n_rows_base * n_cols;
  123. for(uword col = 0; col < n_cols_base; col += block_size)
  124. {
  125. const uword X_offset = col * n_rows;
  126. op_strans::block_worker(&Y[col + Y_offset], &X[n_rows_base + X_offset], n_rows, n_cols, n_rows_extra, block_size);
  127. }
  128. const uword X_offset = n_cols_base * n_rows;
  129. op_strans::block_worker(&Y[n_cols_base + Y_offset], &X[n_rows_base + X_offset], n_rows, n_cols, n_rows_extra, n_cols_extra);
  130. }
  131. //! Immediate transpose of a dense matrix
  132. template<typename eT, typename TA>
  133. arma_hot
  134. inline
  135. void
  136. op_strans::apply_mat_noalias(Mat<eT>& out, const TA& A)
  137. {
  138. arma_extra_debug_sigprint();
  139. const uword A_n_cols = A.n_cols;
  140. const uword A_n_rows = A.n_rows;
  141. out.set_size(A_n_cols, A_n_rows);
  142. if( (TA::is_row) || (TA::is_col) || (A_n_cols == 1) || (A_n_rows == 1) )
  143. {
  144. arrayops::copy( out.memptr(), A.memptr(), A.n_elem );
  145. }
  146. else
  147. {
  148. if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) )
  149. {
  150. op_strans::apply_mat_noalias_tinysq(out, A);
  151. }
  152. else
  153. if( (A_n_rows >= 512) && (A_n_cols >= 512) )
  154. {
  155. op_strans::apply_mat_noalias_large(out, A);
  156. }
  157. else
  158. {
  159. eT* outptr = out.memptr();
  160. for(uword k=0; k < A_n_rows; ++k)
  161. {
  162. const eT* Aptr = &(A.at(k,0));
  163. uword j;
  164. for(j=1; j < A_n_cols; j+=2)
  165. {
  166. const eT tmp_i = (*Aptr); Aptr += A_n_rows;
  167. const eT tmp_j = (*Aptr); Aptr += A_n_rows;
  168. (*outptr) = tmp_i; outptr++;
  169. (*outptr) = tmp_j; outptr++;
  170. }
  171. if((j-1) < A_n_cols)
  172. {
  173. (*outptr) = (*Aptr); outptr++;;
  174. }
  175. }
  176. }
  177. }
  178. }
  179. template<typename eT>
  180. arma_hot
  181. inline
  182. void
  183. op_strans::apply_mat_inplace(Mat<eT>& out)
  184. {
  185. arma_extra_debug_sigprint();
  186. const uword n_rows = out.n_rows;
  187. const uword n_cols = out.n_cols;
  188. if(n_rows == n_cols)
  189. {
  190. arma_extra_debug_print("op_strans::apply(): doing in-place transpose of a square matrix");
  191. const uword N = n_rows;
  192. for(uword k=0; k < N; ++k)
  193. {
  194. eT* colptr = &(out.at(k,k));
  195. eT* rowptr = colptr;
  196. colptr++;
  197. rowptr += N;
  198. uword j;
  199. for(j=(k+2); j < N; j+=2)
  200. {
  201. std::swap( (*rowptr), (*colptr) ); rowptr += N; colptr++;
  202. std::swap( (*rowptr), (*colptr) ); rowptr += N; colptr++;
  203. }
  204. if((j-1) < N)
  205. {
  206. std::swap( (*rowptr), (*colptr) );
  207. }
  208. }
  209. }
  210. else
  211. {
  212. Mat<eT> tmp;
  213. op_strans::apply_mat_noalias(tmp, out);
  214. out.steal_mem(tmp);
  215. }
  216. }
  217. template<typename eT, typename TA>
  218. arma_hot
  219. inline
  220. void
  221. op_strans::apply_mat(Mat<eT>& out, const TA& A)
  222. {
  223. arma_extra_debug_sigprint();
  224. if(&out != &A)
  225. {
  226. op_strans::apply_mat_noalias(out, A);
  227. }
  228. else
  229. {
  230. op_strans::apply_mat_inplace(out);
  231. }
  232. }
  233. template<typename T1>
  234. arma_hot
  235. inline
  236. void
  237. op_strans::apply_proxy(Mat<typename T1::elem_type>& out, const T1& X)
  238. {
  239. arma_extra_debug_sigprint();
  240. typedef typename T1::elem_type eT;
  241. const Proxy<T1> P(X);
  242. const uword n_rows = P.get_n_rows();
  243. const uword n_cols = P.get_n_cols();
  244. const bool is_alias = P.is_alias(out);
  245. if( (resolves_to_vector<T1>::yes) && (Proxy<T1>::use_at == false) )
  246. {
  247. if(is_alias == false)
  248. {
  249. out.set_size(n_cols, n_rows);
  250. eT* out_mem = out.memptr();
  251. const uword n_elem = P.get_n_elem();
  252. typename Proxy<T1>::ea_type Pea = P.get_ea();
  253. uword i,j;
  254. for(i=0, j=1; j < n_elem; i+=2, j+=2)
  255. {
  256. const eT tmp_i = Pea[i];
  257. const eT tmp_j = Pea[j];
  258. out_mem[i] = tmp_i;
  259. out_mem[j] = tmp_j;
  260. }
  261. if(i < n_elem)
  262. {
  263. out_mem[i] = Pea[i];
  264. }
  265. }
  266. else // aliasing
  267. {
  268. Mat<eT> out2(n_cols, n_rows);
  269. eT* out_mem = out2.memptr();
  270. const uword n_elem = P.get_n_elem();
  271. typename Proxy<T1>::ea_type Pea = P.get_ea();
  272. uword i,j;
  273. for(i=0, j=1; j < n_elem; i+=2, j+=2)
  274. {
  275. const eT tmp_i = Pea[i];
  276. const eT tmp_j = Pea[j];
  277. out_mem[i] = tmp_i;
  278. out_mem[j] = tmp_j;
  279. }
  280. if(i < n_elem)
  281. {
  282. out_mem[i] = Pea[i];
  283. }
  284. out.steal_mem(out2);
  285. }
  286. }
  287. else // general matrix transpose
  288. {
  289. if(is_alias == false)
  290. {
  291. out.set_size(n_cols, n_rows);
  292. eT* outptr = out.memptr();
  293. for(uword k=0; k < n_rows; ++k)
  294. {
  295. uword j;
  296. for(j=1; j < n_cols; j+=2)
  297. {
  298. const uword i = j-1;
  299. const eT tmp_i = P.at(k,i);
  300. const eT tmp_j = P.at(k,j);
  301. (*outptr) = tmp_i; outptr++;
  302. (*outptr) = tmp_j; outptr++;
  303. }
  304. const uword i = j-1;
  305. if(i < n_cols)
  306. {
  307. (*outptr) = P.at(k,i); outptr++;
  308. }
  309. }
  310. }
  311. else // aliasing
  312. {
  313. Mat<eT> out2(n_cols, n_rows);
  314. eT* out2ptr = out2.memptr();
  315. for(uword k=0; k < n_rows; ++k)
  316. {
  317. uword j;
  318. for(j=1; j < n_cols; j+=2)
  319. {
  320. const uword i = j-1;
  321. const eT tmp_i = P.at(k,i);
  322. const eT tmp_j = P.at(k,j);
  323. (*out2ptr) = tmp_i; out2ptr++;
  324. (*out2ptr) = tmp_j; out2ptr++;
  325. }
  326. const uword i = j-1;
  327. if(i < n_cols)
  328. {
  329. (*out2ptr) = P.at(k,i); out2ptr++;
  330. }
  331. }
  332. out.steal_mem(out2);
  333. }
  334. }
  335. }
  336. template<typename T1>
  337. arma_hot
  338. inline
  339. void
  340. op_strans::apply_direct(Mat<typename T1::elem_type>& out, const T1& X)
  341. {
  342. arma_extra_debug_sigprint();
  343. // allow detection of in-place transpose
  344. if(is_Mat<T1>::value || is_Mat<typename Proxy<T1>::stored_type>::value)
  345. {
  346. const unwrap<T1> U(X);
  347. op_strans::apply_mat(out, U.M);
  348. }
  349. else
  350. {
  351. op_strans::apply_proxy(out, X);
  352. }
  353. }
  354. template<typename T1>
  355. arma_hot
  356. inline
  357. void
  358. op_strans::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_strans>& in)
  359. {
  360. arma_extra_debug_sigprint();
  361. op_strans::apply_direct(out, in.m);
  362. }
  363. //
  364. //
  365. //
  366. template<typename eT>
  367. inline
  368. void
  369. op_strans_cube::apply_noalias(Cube<eT>& out, const Cube<eT>& X)
  370. {
  371. out.set_size(X.n_cols, X.n_rows, X.n_slices);
  372. for(uword s=0; s < X.n_slices; ++s)
  373. {
  374. Mat<eT> out_slice( out.slice_memptr(s), X.n_cols, X.n_rows, false, true );
  375. const Mat<eT> X_slice( const_cast<eT*>(X.slice_memptr(s)), X.n_rows, X.n_cols, false, true );
  376. op_strans::apply_mat_noalias(out_slice, X_slice);
  377. }
  378. }
  379. //! @}