op_diagmat_meat.hpp 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup op_diagmat
  16. //! @{
  17. template<typename T1>
  18. inline
  19. void
  20. op_diagmat::apply(Mat<typename T1::elem_type>& out, const Op<T1, op_diagmat>& X)
  21. {
  22. arma_extra_debug_sigprint();
  23. typedef typename T1::elem_type eT;
  24. if(is_Mat<T1>::value)
  25. {
  26. // allow detection of in-place operation
  27. const unwrap<T1> U(X.m);
  28. const Mat<eT>& A = U.M;
  29. if(&out != &A) // no aliasing
  30. {
  31. const Proxy< Mat<eT> > P(A);
  32. op_diagmat::apply(out, P);
  33. }
  34. else // we have aliasing
  35. {
  36. const uword n_rows = out.n_rows;
  37. const uword n_cols = out.n_cols;
  38. if((n_rows == 1) || (n_cols == 1)) // create diagonal matrix from vector
  39. {
  40. const eT* out_mem = out.memptr();
  41. const uword N = out.n_elem;
  42. Mat<eT> tmp(N,N); tmp.zeros();
  43. for(uword i=0; i<N; ++i) { tmp.at(i,i) = out_mem[i]; }
  44. out.steal_mem(tmp);
  45. }
  46. else // create diagonal matrix from matrix
  47. {
  48. const uword N = (std::min)(n_rows, n_cols);
  49. for(uword i=0; i < n_cols; ++i)
  50. {
  51. if(i < N)
  52. {
  53. eT& out_ii = out.at(i,i);
  54. const eT val = out_ii;
  55. arrayops::fill_zeros(out.colptr(i), n_rows);
  56. out_ii = val;
  57. }
  58. else
  59. {
  60. arrayops::fill_zeros(out.colptr(i), n_rows);
  61. }
  62. }
  63. }
  64. }
  65. }
  66. else
  67. {
  68. const Proxy<T1> P(X.m);
  69. if(P.is_alias(out))
  70. {
  71. Mat<eT> tmp;
  72. op_diagmat::apply(tmp, P);
  73. out.steal_mem(tmp);
  74. }
  75. else
  76. {
  77. op_diagmat::apply(out, P);
  78. }
  79. }
  80. }
  81. template<typename T1>
  82. inline
  83. void
  84. op_diagmat::apply(Mat<typename T1::elem_type>& out, const Proxy<T1>& P)
  85. {
  86. arma_extra_debug_sigprint();
  87. const uword n_rows = P.get_n_rows();
  88. const uword n_cols = P.get_n_cols();
  89. const uword n_elem = P.get_n_elem();
  90. if(n_elem == 0) { out.reset(); return; }
  91. const bool P_is_vec = (T1::is_row) || (T1::is_col) || (n_rows == 1) || (n_cols == 1);
  92. if(P_is_vec)
  93. {
  94. out.zeros(n_elem, n_elem);
  95. if(Proxy<T1>::use_at == false)
  96. {
  97. typename Proxy<T1>::ea_type Pea = P.get_ea();
  98. for(uword i=0; i < n_elem; ++i) { out.at(i,i) = Pea[i]; }
  99. }
  100. else
  101. {
  102. if(n_rows == 1)
  103. {
  104. for(uword i=0; i < n_elem; ++i) { out.at(i,i) = P.at(0,i); }
  105. }
  106. else
  107. {
  108. for(uword i=0; i < n_elem; ++i) { out.at(i,i) = P.at(i,0); }
  109. }
  110. }
  111. }
  112. else // P represents a matrix
  113. {
  114. out.zeros(n_rows, n_cols);
  115. const uword N = (std::min)(n_rows, n_cols);
  116. for(uword i=0; i<N; ++i) { out.at(i,i) = P.at(i,i); }
  117. }
  118. }
  119. template<typename T1, typename T2>
  120. inline
  121. void
  122. op_diagmat::apply(Mat<typename T1::elem_type>& out, const Op< Glue<T1,T2,glue_times>, op_diagmat>& X)
  123. {
  124. arma_extra_debug_sigprint();
  125. op_diagmat::apply_times(out, X.m.A, X.m.B);
  126. }
  127. template<typename T1, typename T2>
  128. inline
  129. void
  130. op_diagmat::apply_times(Mat<typename T1::elem_type>& actual_out, const T1& X, const T2& Y, const typename arma_not_cx<typename T1::elem_type>::result* junk)
  131. {
  132. arma_extra_debug_sigprint();
  133. arma_ignore(junk);
  134. typedef typename T1::elem_type eT;
  135. const partial_unwrap<T1> UA(X);
  136. const partial_unwrap<T2> UB(Y);
  137. const typename partial_unwrap<T1>::stored_type& A = UA.M;
  138. const typename partial_unwrap<T2>::stored_type& B = UB.M;
  139. arma_debug_assert_trans_mul_size< partial_unwrap<T1>::do_trans, partial_unwrap<T2>::do_trans >(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication");
  140. const bool use_alpha = partial_unwrap<T1>::do_times || partial_unwrap<T2>::do_times;
  141. const eT alpha = use_alpha ? (UA.get_val() * UB.get_val()) : eT(0);
  142. const uword A_n_rows = A.n_rows;
  143. const uword A_n_cols = A.n_cols;
  144. const uword B_n_rows = B.n_rows;
  145. const uword B_n_cols = B.n_cols;
  146. // check if the multiplication results in a vector
  147. if( (partial_unwrap<T1>::do_trans == false) && (partial_unwrap<T2>::do_trans == false) )
  148. {
  149. if((A_n_rows == 1) || (B_n_cols == 1))
  150. {
  151. arma_extra_debug_print("trans_A = false; trans_B = false; vector result");
  152. const Mat<eT> C = A*B;
  153. const eT* C_mem = C.memptr();
  154. const uword N = C.n_elem;
  155. actual_out.zeros(N,N);
  156. for(uword i=0; i<N; ++i) { actual_out.at(i,i) = (use_alpha) ? eT(alpha * C_mem[i]) : eT(C_mem[i]); }
  157. return;
  158. }
  159. }
  160. else
  161. if( (partial_unwrap<T1>::do_trans == true ) && (partial_unwrap<T2>::do_trans == false) )
  162. {
  163. if((A_n_cols == 1) || (B_n_cols == 1))
  164. {
  165. arma_extra_debug_print("trans_A = true; trans_B = false; vector result");
  166. const Mat<eT> C = trans(A)*B;
  167. const eT* C_mem = C.memptr();
  168. const uword N = C.n_elem;
  169. actual_out.zeros(N,N);
  170. for(uword i=0; i<N; ++i) { actual_out.at(i,i) = (use_alpha) ? eT(alpha * C_mem[i]) : eT(C_mem[i]); }
  171. return;
  172. }
  173. }
  174. else
  175. if( (partial_unwrap<T1>::do_trans == false) && (partial_unwrap<T2>::do_trans == true ) )
  176. {
  177. if((A_n_rows == 1) || (B_n_rows == 1))
  178. {
  179. arma_extra_debug_print("trans_A = false; trans_B = true; vector result");
  180. const Mat<eT> C = A*trans(B);
  181. const eT* C_mem = C.memptr();
  182. const uword N = C.n_elem;
  183. actual_out.zeros(N,N);
  184. for(uword i=0; i<N; ++i) { actual_out.at(i,i) = (use_alpha) ? eT(alpha * C_mem[i]) : eT(C_mem[i]); }
  185. return;
  186. }
  187. }
  188. else
  189. if( (partial_unwrap<T1>::do_trans == true ) && (partial_unwrap<T2>::do_trans == true ) )
  190. {
  191. if((A_n_cols == 1) || (B_n_rows == 1))
  192. {
  193. arma_extra_debug_print("trans_A = true; trans_B = true; vector result");
  194. const Mat<eT> C = trans(A)*trans(B);
  195. const eT* C_mem = C.memptr();
  196. const uword N = C.n_elem;
  197. actual_out.zeros(N,N);
  198. for(uword i=0; i<N; ++i) { actual_out.at(i,i) = (use_alpha) ? eT(alpha * C_mem[i]) : eT(C_mem[i]); }
  199. return;
  200. }
  201. }
  202. // if we got to this point, the multiplication results in a matrix
  203. const bool is_alias = (UA.is_alias(actual_out) || UB.is_alias(actual_out));
  204. Mat<eT> tmp;
  205. Mat<eT>& out = (is_alias) ? tmp : actual_out;
  206. if( (partial_unwrap<T1>::do_trans == false) && (partial_unwrap<T2>::do_trans == false) )
  207. {
  208. arma_extra_debug_print("trans_A = false; trans_B = false; matrix result");
  209. out.zeros(A_n_rows, B_n_cols);
  210. const uword N = (std::min)(A_n_rows, B_n_cols);
  211. for(uword k=0; k < N; ++k)
  212. {
  213. eT acc1 = eT(0);
  214. eT acc2 = eT(0);
  215. const eT* B_colptr = B.colptr(k);
  216. // condition: A_n_cols = B_n_rows
  217. uword j;
  218. for(j=1; j < A_n_cols; j+=2)
  219. {
  220. const uword i = (j-1);
  221. const eT tmp_i = B_colptr[i];
  222. const eT tmp_j = B_colptr[j];
  223. acc1 += A.at(k, i) * tmp_i;
  224. acc2 += A.at(k, j) * tmp_j;
  225. }
  226. const uword i = (j-1);
  227. if(i < A_n_cols)
  228. {
  229. acc1 += A.at(k, i) * B_colptr[i];
  230. }
  231. const eT acc = acc1 + acc2;
  232. out.at(k,k) = (use_alpha) ? eT(alpha * acc) : eT(acc);
  233. }
  234. }
  235. else
  236. if( (partial_unwrap<T1>::do_trans == true ) && (partial_unwrap<T2>::do_trans == false) )
  237. {
  238. arma_extra_debug_print("trans_A = true; trans_B = false; matrix result");
  239. out.zeros(A_n_cols, B_n_cols);
  240. const uword N = (std::min)(A_n_cols, B_n_cols);
  241. for(uword k=0; k < N; ++k)
  242. {
  243. const eT* A_colptr = A.colptr(k);
  244. const eT* B_colptr = B.colptr(k);
  245. // condition: A_n_rows = B_n_rows
  246. const eT acc = op_dot::direct_dot(A_n_rows, A_colptr, B_colptr);
  247. out.at(k,k) = (use_alpha) ? eT(alpha * acc) : eT(acc);
  248. }
  249. }
  250. else
  251. if( (partial_unwrap<T1>::do_trans == false) && (partial_unwrap<T2>::do_trans == true ) )
  252. {
  253. arma_extra_debug_print("trans_A = false; trans_B = true; matrix result");
  254. out.zeros(A_n_rows, B_n_rows);
  255. const uword N = (std::min)(A_n_rows, B_n_rows);
  256. for(uword k=0; k < N; ++k)
  257. {
  258. eT acc = eT(0);
  259. // condition: A_n_cols = B_n_cols
  260. for(uword i=0; i < A_n_cols; ++i)
  261. {
  262. acc += A.at(k,i) * B.at(k,i);
  263. }
  264. out.at(k,k) = (use_alpha) ? eT(alpha * acc) : eT(acc);
  265. }
  266. }
  267. else
  268. if( (partial_unwrap<T1>::do_trans == true ) && (partial_unwrap<T2>::do_trans == true ) )
  269. {
  270. arma_extra_debug_print("trans_A = true; trans_B = true; matrix result");
  271. out.zeros(A_n_cols, B_n_rows);
  272. const uword N = (std::min)(A_n_cols, B_n_rows);
  273. for(uword k=0; k < N; ++k)
  274. {
  275. eT acc = eT(0);
  276. const eT* A_colptr = A.colptr(k);
  277. // condition: A_n_rows = B_n_cols
  278. for(uword i=0; i < A_n_rows; ++i)
  279. {
  280. acc += A_colptr[i] * B.at(k,i);
  281. }
  282. out.at(k,k) = (use_alpha) ? eT(alpha * acc) : eT(acc);
  283. }
  284. }
  285. if(is_alias) { actual_out.steal_mem(tmp); }
  286. }
  287. template<typename T1, typename T2>
  288. inline
  289. void
  290. op_diagmat::apply_times(Mat<typename T1::elem_type>& actual_out, const T1& X, const T2& Y, const typename arma_cx_only<typename T1::elem_type>::result* junk)
  291. {
  292. arma_extra_debug_sigprint();
  293. arma_ignore(junk);
  294. typedef typename T1::pod_type T;
  295. typedef typename T1::elem_type eT;
  296. const partial_unwrap<T1> UA(X);
  297. const partial_unwrap<T2> UB(Y);
  298. const typename partial_unwrap<T1>::stored_type& A = UA.M;
  299. const typename partial_unwrap<T2>::stored_type& B = UB.M;
  300. arma_debug_assert_trans_mul_size< partial_unwrap<T1>::do_trans, partial_unwrap<T2>::do_trans >(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication");
  301. const bool use_alpha = partial_unwrap<T1>::do_times || partial_unwrap<T2>::do_times;
  302. const eT alpha = use_alpha ? (UA.get_val() * UB.get_val()) : eT(0);
  303. const uword A_n_rows = A.n_rows;
  304. const uword A_n_cols = A.n_cols;
  305. const uword B_n_rows = B.n_rows;
  306. const uword B_n_cols = B.n_cols;
  307. // check if the multiplication results in a vector
  308. if( (partial_unwrap<T1>::do_trans == false) && (partial_unwrap<T2>::do_trans == false) )
  309. {
  310. if((A_n_rows == 1) || (B_n_cols == 1))
  311. {
  312. arma_extra_debug_print("trans_A = false; trans_B = false; vector result");
  313. const Mat<eT> C = A*B;
  314. const eT* C_mem = C.memptr();
  315. const uword N = C.n_elem;
  316. actual_out.zeros(N,N);
  317. for(uword i=0; i<N; ++i) { actual_out.at(i,i) = (use_alpha) ? eT(alpha * C_mem[i]) : eT(C_mem[i]); }
  318. return;
  319. }
  320. }
  321. else
  322. if( (partial_unwrap<T1>::do_trans == true ) && (partial_unwrap<T2>::do_trans == false) )
  323. {
  324. if((A_n_cols == 1) || (B_n_cols == 1))
  325. {
  326. arma_extra_debug_print("trans_A = true; trans_B = false; vector result");
  327. const Mat<eT> C = trans(A)*B;
  328. const eT* C_mem = C.memptr();
  329. const uword N = C.n_elem;
  330. actual_out.zeros(N,N);
  331. for(uword i=0; i<N; ++i) { actual_out.at(i,i) = (use_alpha) ? eT(alpha * C_mem[i]) : eT(C_mem[i]); }
  332. return;
  333. }
  334. }
  335. else
  336. if( (partial_unwrap<T1>::do_trans == false) && (partial_unwrap<T2>::do_trans == true ) )
  337. {
  338. if((A_n_rows == 1) || (B_n_rows == 1))
  339. {
  340. arma_extra_debug_print("trans_A = false; trans_B = true; vector result");
  341. const Mat<eT> C = A*trans(B);
  342. const eT* C_mem = C.memptr();
  343. const uword N = C.n_elem;
  344. actual_out.zeros(N,N);
  345. for(uword i=0; i<N; ++i) { actual_out.at(i,i) = (use_alpha) ? eT(alpha * C_mem[i]) : eT(C_mem[i]); }
  346. return;
  347. }
  348. }
  349. else
  350. if( (partial_unwrap<T1>::do_trans == true ) && (partial_unwrap<T2>::do_trans == true ) )
  351. {
  352. if((A_n_cols == 1) || (B_n_rows == 1))
  353. {
  354. arma_extra_debug_print("trans_A = true; trans_B = true; vector result");
  355. const Mat<eT> C = trans(A)*trans(B);
  356. const eT* C_mem = C.memptr();
  357. const uword N = C.n_elem;
  358. actual_out.zeros(N,N);
  359. for(uword i=0; i<N; ++i) { actual_out.at(i,i) = (use_alpha) ? eT(alpha * C_mem[i]) : eT(C_mem[i]); }
  360. return;
  361. }
  362. }
  363. // if we got to this point, the multiplication results in a matrix
  364. const bool is_alias = (UA.is_alias(actual_out) || UB.is_alias(actual_out));
  365. Mat<eT> tmp;
  366. Mat<eT>& out = (is_alias) ? tmp : actual_out;
  367. if( (partial_unwrap<T1>::do_trans == false) && (partial_unwrap<T2>::do_trans == false) )
  368. {
  369. arma_extra_debug_print("trans_A = false; trans_B = false; matrix result");
  370. out.zeros(A_n_rows, B_n_cols);
  371. const uword N = (std::min)(A_n_rows, B_n_cols);
  372. for(uword k=0; k < N; ++k)
  373. {
  374. T acc_real = T(0);
  375. T acc_imag = T(0);
  376. const eT* B_colptr = B.colptr(k);
  377. // condition: A_n_cols = B_n_rows
  378. for(uword i=0; i < A_n_cols; ++i)
  379. {
  380. // acc += A.at(k, i) * B_colptr[i];
  381. const std::complex<T>& xx = A.at(k, i);
  382. const std::complex<T>& yy = B_colptr[i];
  383. const T a = xx.real();
  384. const T b = xx.imag();
  385. const T c = yy.real();
  386. const T d = yy.imag();
  387. acc_real += (a*c) - (b*d);
  388. acc_imag += (a*d) + (b*c);
  389. }
  390. const eT acc = std::complex<T>(acc_real, acc_imag);
  391. out.at(k,k) = (use_alpha) ? eT(alpha * acc) : eT(acc);
  392. }
  393. }
  394. else
  395. if( (partial_unwrap<T1>::do_trans == true) && (partial_unwrap<T2>::do_trans == false) )
  396. {
  397. arma_extra_debug_print("trans_A = true; trans_B = false; matrix result");
  398. out.zeros(A_n_cols, B_n_cols);
  399. const uword N = (std::min)(A_n_cols, B_n_cols);
  400. for(uword k=0; k < N; ++k)
  401. {
  402. T acc_real = T(0);
  403. T acc_imag = T(0);
  404. const eT* A_colptr = A.colptr(k);
  405. const eT* B_colptr = B.colptr(k);
  406. // condition: A_n_rows = B_n_rows
  407. for(uword i=0; i < A_n_rows; ++i)
  408. {
  409. // acc += std::conj(A_colptr[i]) * B_colptr[i];
  410. const std::complex<T>& xx = A_colptr[i];
  411. const std::complex<T>& yy = B_colptr[i];
  412. const T a = xx.real();
  413. const T b = xx.imag();
  414. const T c = yy.real();
  415. const T d = yy.imag();
  416. // take into account the complex conjugate of xx
  417. acc_real += (a*c) + (b*d);
  418. acc_imag += (a*d) - (b*c);
  419. }
  420. const eT acc = std::complex<T>(acc_real, acc_imag);
  421. out.at(k,k) = (use_alpha) ? eT(alpha * acc) : eT(acc);
  422. }
  423. }
  424. else
  425. if( (partial_unwrap<T1>::do_trans == false) && (partial_unwrap<T2>::do_trans == true) )
  426. {
  427. arma_extra_debug_print("trans_A = false; trans_B = true; matrix result");
  428. out.zeros(A_n_rows, B_n_rows);
  429. const uword N = (std::min)(A_n_rows, B_n_rows);
  430. for(uword k=0; k < N; ++k)
  431. {
  432. T acc_real = T(0);
  433. T acc_imag = T(0);
  434. // condition: A_n_cols = B_n_cols
  435. for(uword i=0; i < A_n_cols; ++i)
  436. {
  437. // acc += A.at(k,i) * std::conj(B.at(k,i));
  438. const std::complex<T>& xx = A.at(k, i);
  439. const std::complex<T>& yy = B.at(k, i);
  440. const T a = xx.real();
  441. const T b = xx.imag();
  442. const T c = yy.real();
  443. const T d = -yy.imag(); // take the conjugate
  444. acc_real += (a*c) - (b*d);
  445. acc_imag += (a*d) + (b*c);
  446. }
  447. const eT acc = std::complex<T>(acc_real, acc_imag);
  448. out.at(k,k) = (use_alpha) ? eT(alpha * acc) : eT(acc);
  449. }
  450. }
  451. else
  452. if( (partial_unwrap<T1>::do_trans == true) && (partial_unwrap<T2>::do_trans == true) )
  453. {
  454. arma_extra_debug_print("trans_A = true; trans_B = true; matrix result");
  455. out.zeros(A_n_cols, B_n_rows);
  456. const uword N = (std::min)(A_n_cols, B_n_rows);
  457. for(uword k=0; k < N; ++k)
  458. {
  459. T acc_real = T(0);
  460. T acc_imag = T(0);
  461. const eT* A_colptr = A.colptr(k);
  462. // condition: A_n_rows = B_n_cols
  463. for(uword i=0; i < A_n_rows; ++i)
  464. {
  465. // acc += std::conj(A_colptr[i]) * std::conj(B.at(k,i));
  466. const std::complex<T>& xx = A_colptr[i];
  467. const std::complex<T>& yy = B.at(k, i);
  468. const T a = xx.real();
  469. const T b = -xx.imag(); // take the conjugate
  470. const T c = yy.real();
  471. const T d = -yy.imag(); // take the conjugate
  472. acc_real += (a*c) - (b*d);
  473. acc_imag += (a*d) + (b*c);
  474. }
  475. const eT acc = std::complex<T>(acc_real, acc_imag);
  476. out.at(k,k) = (use_alpha) ? eT(alpha * acc) : eT(acc);
  477. }
  478. }
  479. if(is_alias) { actual_out.steal_mem(tmp); }
  480. }
  481. //
  482. //
  483. //
  484. template<typename T1>
  485. inline
  486. void
  487. op_diagmat2::apply(Mat<typename T1::elem_type>& out, const Op<T1, op_diagmat2>& X)
  488. {
  489. arma_extra_debug_sigprint();
  490. typedef typename T1::elem_type eT;
  491. const uword row_offset = X.aux_uword_a;
  492. const uword col_offset = X.aux_uword_b;
  493. const Proxy<T1> P(X.m);
  494. if(P.is_alias(out))
  495. {
  496. Mat<eT> tmp;
  497. op_diagmat2::apply(tmp, P, row_offset, col_offset);
  498. out.steal_mem(tmp);
  499. }
  500. else
  501. {
  502. op_diagmat2::apply(out, P, row_offset, col_offset);
  503. }
  504. }
  505. template<typename T1>
  506. inline
  507. void
  508. op_diagmat2::apply(Mat<typename T1::elem_type>& out, const Proxy<T1>& P, const uword row_offset, const uword col_offset)
  509. {
  510. arma_extra_debug_sigprint();
  511. const uword n_rows = P.get_n_rows();
  512. const uword n_cols = P.get_n_cols();
  513. const uword n_elem = P.get_n_elem();
  514. if(n_elem == 0) { out.reset(); return; }
  515. const bool P_is_vec = (T1::is_row) || (T1::is_col) || (n_rows == 1) || (n_cols == 1);
  516. if(P_is_vec)
  517. {
  518. const uword n_pad = (std::max)(row_offset, col_offset);
  519. out.zeros(n_elem + n_pad, n_elem + n_pad);
  520. if(Proxy<T1>::use_at == false)
  521. {
  522. typename Proxy<T1>::ea_type Pea = P.get_ea();
  523. for(uword i=0; i < n_elem; ++i) { out.at(row_offset + i, col_offset + i) = Pea[i]; }
  524. }
  525. else
  526. {
  527. if(n_rows == 1)
  528. {
  529. for(uword i=0; i < n_elem; ++i) { out.at(row_offset + i, col_offset + i) = P.at(0,i); }
  530. }
  531. else
  532. {
  533. for(uword i=0; i < n_elem; ++i) { out.at(row_offset + i, col_offset + i) = P.at(i,0); }
  534. }
  535. }
  536. }
  537. else // P represents a matrix
  538. {
  539. arma_debug_check
  540. (
  541. ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)),
  542. "diagmat(): requested diagonal out of bounds"
  543. );
  544. out.zeros(n_rows, n_cols);
  545. const uword N = (std::min)(n_rows - row_offset, n_cols - col_offset);
  546. for(uword i=0; i<N; ++i)
  547. {
  548. const uword row = i + row_offset;
  549. const uword col = i + col_offset;
  550. out.at(row,col) = P.at(row,col);
  551. }
  552. }
  553. }
  554. //! @}