spdiagview_meat.hpp 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup spdiagview
  16. //! @{
  17. template<typename eT>
  18. inline
  19. spdiagview<eT>::~spdiagview()
  20. {
  21. arma_extra_debug_sigprint();
  22. }
  23. template<typename eT>
  24. arma_inline
  25. spdiagview<eT>::spdiagview(const SpMat<eT>& in_m, const uword in_row_offset, const uword in_col_offset, const uword in_len)
  26. : m(in_m)
  27. , row_offset(in_row_offset)
  28. , col_offset(in_col_offset)
  29. , n_rows(in_len)
  30. , n_elem(in_len)
  31. {
  32. arma_extra_debug_sigprint();
  33. }
  34. //! set a diagonal of our matrix using a diagonal from a foreign matrix
  35. template<typename eT>
  36. inline
  37. void
  38. spdiagview<eT>::operator= (const spdiagview<eT>& x)
  39. {
  40. arma_extra_debug_sigprint();
  41. spdiagview<eT>& d = *this;
  42. arma_debug_check( (d.n_elem != x.n_elem), "spdiagview: diagonals have incompatible lengths" );
  43. SpMat<eT>& d_m = const_cast< SpMat<eT>& >(d.m);
  44. const SpMat<eT>& x_m = x.m;
  45. if( (&d_m == &x_m) || ((d.row_offset == 0) && (d.col_offset == 0)) )
  46. {
  47. const Mat<eT> tmp(x);
  48. (*this).operator=(tmp);
  49. }
  50. else
  51. {
  52. const uword d_n_elem = d.n_elem;
  53. const uword d_row_offset = d.row_offset;
  54. const uword d_col_offset = d.col_offset;
  55. const uword x_row_offset = x.row_offset;
  56. const uword x_col_offset = x.col_offset;
  57. for(uword i=0; i < d_n_elem; ++i)
  58. {
  59. d_m.at(i + d_row_offset, i + d_col_offset) = x_m.at(i + x_row_offset, i + x_col_offset);
  60. }
  61. }
  62. }
  63. template<typename eT>
  64. inline
  65. void
  66. spdiagview<eT>::operator+=(const eT val)
  67. {
  68. arma_extra_debug_sigprint();
  69. if(val == eT(0)) { return; }
  70. SpMat<eT>& t_m = const_cast< SpMat<eT>& >(m);
  71. const uword t_n_elem = n_elem;
  72. const uword t_row_offset = row_offset;
  73. const uword t_col_offset = col_offset;
  74. for(uword i=0; i < t_n_elem; ++i)
  75. {
  76. t_m.at(i + t_row_offset, i + t_col_offset) += val;
  77. }
  78. }
  79. template<typename eT>
  80. inline
  81. void
  82. spdiagview<eT>::operator-=(const eT val)
  83. {
  84. arma_extra_debug_sigprint();
  85. if(val == eT(0)) { return; }
  86. SpMat<eT>& t_m = const_cast< SpMat<eT>& >(m);
  87. const uword t_n_elem = n_elem;
  88. const uword t_row_offset = row_offset;
  89. const uword t_col_offset = col_offset;
  90. for(uword i=0; i < t_n_elem; ++i)
  91. {
  92. t_m.at(i + t_row_offset, i + t_col_offset) -= val;
  93. }
  94. }
  95. template<typename eT>
  96. inline
  97. void
  98. spdiagview<eT>::operator*=(const eT val)
  99. {
  100. arma_extra_debug_sigprint();
  101. if(val == eT(0)) { (*this).zeros(); return; }
  102. SpMat<eT>& t_m = const_cast< SpMat<eT>& >(m);
  103. const uword t_n_elem = n_elem;
  104. const uword t_row_offset = row_offset;
  105. const uword t_col_offset = col_offset;
  106. for(uword i=0; i < t_n_elem; ++i)
  107. {
  108. t_m.at(i + t_row_offset, i + t_col_offset) *= val;
  109. }
  110. }
  111. template<typename eT>
  112. inline
  113. void
  114. spdiagview<eT>::operator/=(const eT val)
  115. {
  116. arma_extra_debug_sigprint();
  117. SpMat<eT>& t_m = const_cast< SpMat<eT>& >(m);
  118. const uword t_n_elem = n_elem;
  119. const uword t_row_offset = row_offset;
  120. const uword t_col_offset = col_offset;
  121. for(uword i=0; i < t_n_elem; ++i)
  122. {
  123. t_m.at(i + t_row_offset, i + t_col_offset) /= val;
  124. }
  125. }
  126. //! set a diagonal of our matrix using data from a foreign object
  127. template<typename eT>
  128. template<typename T1>
  129. inline
  130. void
  131. spdiagview<eT>::operator= (const Base<eT,T1>& o)
  132. {
  133. arma_extra_debug_sigprint();
  134. spdiagview<eT>& d = *this;
  135. SpMat<eT>& d_m = const_cast< SpMat<eT>& >(d.m);
  136. const uword d_n_elem = d.n_elem;
  137. const uword d_row_offset = d.row_offset;
  138. const uword d_col_offset = d.col_offset;
  139. const quasi_unwrap<T1> U(o.get_ref());
  140. const Mat<eT>& x = U.M;
  141. const eT* x_mem = x.memptr();
  142. arma_debug_check
  143. (
  144. ( (d_n_elem != x.n_elem) || ((x.n_rows != 1) && (x.n_cols != 1)) ),
  145. "spdiagview: given object has incompatible size"
  146. );
  147. if( (d_row_offset == 0) && (d_col_offset == 0) )
  148. {
  149. SpMat<eT> tmp1;
  150. tmp1.eye(d_m.n_rows, d_m.n_cols);
  151. bool has_zero = false;
  152. for(uword i=0; i < d_n_elem; ++i)
  153. {
  154. const eT val = x_mem[i];
  155. access::rw(tmp1.values[i]) = val;
  156. if(val == eT(0)) { has_zero = true; }
  157. }
  158. if(has_zero) { tmp1.remove_zeros(); }
  159. SpMat<eT> tmp2;
  160. spglue_merge::diagview_merge(tmp2, d_m, tmp1);
  161. d_m.steal_mem(tmp2);
  162. }
  163. else
  164. {
  165. for(uword i=0; i < d_n_elem; ++i)
  166. {
  167. d_m.at(i + d_row_offset, i + d_col_offset) = x_mem[i];
  168. }
  169. }
  170. }
  171. template<typename eT>
  172. template<typename T1>
  173. inline
  174. void
  175. spdiagview<eT>::operator+=(const Base<eT,T1>& o)
  176. {
  177. arma_extra_debug_sigprint();
  178. spdiagview<eT>& d = *this;
  179. SpMat<eT>& d_m = const_cast< SpMat<eT>& >(d.m);
  180. const uword d_n_elem = d.n_elem;
  181. const uword d_row_offset = d.row_offset;
  182. const uword d_col_offset = d.col_offset;
  183. const Proxy<T1> P( o.get_ref() );
  184. arma_debug_check
  185. (
  186. ( (d_n_elem != P.get_n_elem()) || ((P.get_n_rows() != 1) && (P.get_n_cols() != 1)) ),
  187. "spdiagview: given object has incompatible size"
  188. );
  189. if( (is_Mat<typename Proxy<T1>::stored_type>::value) || (Proxy<T1>::use_at) )
  190. {
  191. const unwrap<typename Proxy<T1>::stored_type> tmp(P.Q);
  192. const Mat<eT>& x = tmp.M;
  193. const eT* x_mem = x.memptr();
  194. for(uword i=0; i < d_n_elem; ++i)
  195. {
  196. d_m.at(i + d_row_offset, i + d_col_offset) += x_mem[i];
  197. }
  198. }
  199. else
  200. {
  201. typename Proxy<T1>::ea_type Pea = P.get_ea();
  202. for(uword i=0; i < d_n_elem; ++i)
  203. {
  204. d_m.at(i + d_row_offset, i + d_col_offset) += Pea[i];
  205. }
  206. }
  207. }
  208. template<typename eT>
  209. template<typename T1>
  210. inline
  211. void
  212. spdiagview<eT>::operator-=(const Base<eT,T1>& o)
  213. {
  214. arma_extra_debug_sigprint();
  215. spdiagview<eT>& d = *this;
  216. SpMat<eT>& d_m = const_cast< SpMat<eT>& >(d.m);
  217. const uword d_n_elem = d.n_elem;
  218. const uword d_row_offset = d.row_offset;
  219. const uword d_col_offset = d.col_offset;
  220. const Proxy<T1> P( o.get_ref() );
  221. arma_debug_check
  222. (
  223. ( (d_n_elem != P.get_n_elem()) || ((P.get_n_rows() != 1) && (P.get_n_cols() != 1)) ),
  224. "spdiagview: given object has incompatible size"
  225. );
  226. if( (is_Mat<typename Proxy<T1>::stored_type>::value) || (Proxy<T1>::use_at) )
  227. {
  228. const unwrap<typename Proxy<T1>::stored_type> tmp(P.Q);
  229. const Mat<eT>& x = tmp.M;
  230. const eT* x_mem = x.memptr();
  231. for(uword i=0; i < d_n_elem; ++i)
  232. {
  233. d_m.at(i + d_row_offset, i + d_col_offset) -= x_mem[i];
  234. }
  235. }
  236. else
  237. {
  238. typename Proxy<T1>::ea_type Pea = P.get_ea();
  239. for(uword i=0; i < d_n_elem; ++i)
  240. {
  241. d_m.at(i + d_row_offset, i + d_col_offset) -= Pea[i];
  242. }
  243. }
  244. }
  245. template<typename eT>
  246. template<typename T1>
  247. inline
  248. void
  249. spdiagview<eT>::operator%=(const Base<eT,T1>& o)
  250. {
  251. arma_extra_debug_sigprint();
  252. spdiagview<eT>& d = *this;
  253. SpMat<eT>& d_m = const_cast< SpMat<eT>& >(d.m);
  254. const uword d_n_elem = d.n_elem;
  255. const uword d_row_offset = d.row_offset;
  256. const uword d_col_offset = d.col_offset;
  257. const Proxy<T1> P( o.get_ref() );
  258. arma_debug_check
  259. (
  260. ( (d_n_elem != P.get_n_elem()) || ((P.get_n_rows() != 1) && (P.get_n_cols() != 1)) ),
  261. "spdiagview: given object has incompatible size"
  262. );
  263. if( (is_Mat<typename Proxy<T1>::stored_type>::value) || (Proxy<T1>::use_at) )
  264. {
  265. const unwrap<typename Proxy<T1>::stored_type> tmp(P.Q);
  266. const Mat<eT>& x = tmp.M;
  267. const eT* x_mem = x.memptr();
  268. for(uword i=0; i < d_n_elem; ++i)
  269. {
  270. d_m.at(i + d_row_offset, i + d_col_offset) *= x_mem[i];
  271. }
  272. }
  273. else
  274. {
  275. typename Proxy<T1>::ea_type Pea = P.get_ea();
  276. for(uword i=0; i < d_n_elem; ++i)
  277. {
  278. d_m.at(i + d_row_offset, i + d_col_offset) *= Pea[i];
  279. }
  280. }
  281. }
  282. template<typename eT>
  283. template<typename T1>
  284. inline
  285. void
  286. spdiagview<eT>::operator/=(const Base<eT,T1>& o)
  287. {
  288. arma_extra_debug_sigprint();
  289. spdiagview<eT>& d = *this;
  290. SpMat<eT>& d_m = const_cast< SpMat<eT>& >(d.m);
  291. const uword d_n_elem = d.n_elem;
  292. const uword d_row_offset = d.row_offset;
  293. const uword d_col_offset = d.col_offset;
  294. const Proxy<T1> P( o.get_ref() );
  295. arma_debug_check
  296. (
  297. ( (d_n_elem != P.get_n_elem()) || ((P.get_n_rows() != 1) && (P.get_n_cols() != 1)) ),
  298. "spdiagview: given object has incompatible size"
  299. );
  300. if( (is_Mat<typename Proxy<T1>::stored_type>::value) || (Proxy<T1>::use_at) )
  301. {
  302. const unwrap<typename Proxy<T1>::stored_type> tmp(P.Q);
  303. const Mat<eT>& x = tmp.M;
  304. const eT* x_mem = x.memptr();
  305. for(uword i=0; i < d_n_elem; ++i)
  306. {
  307. d_m.at(i + d_row_offset, i + d_col_offset) /= x_mem[i];
  308. }
  309. }
  310. else
  311. {
  312. typename Proxy<T1>::ea_type Pea = P.get_ea();
  313. for(uword i=0; i < d_n_elem; ++i)
  314. {
  315. d_m.at(i + d_row_offset, i + d_col_offset) /= Pea[i];
  316. }
  317. }
  318. }
  319. //! set a diagonal of our matrix using data from a foreign object
  320. template<typename eT>
  321. template<typename T1>
  322. inline
  323. void
  324. spdiagview<eT>::operator= (const SpBase<eT,T1>& o)
  325. {
  326. arma_extra_debug_sigprint();
  327. const unwrap_spmat<T1> U( o.get_ref() );
  328. const SpMat<eT>& x = U.M;
  329. arma_debug_check
  330. (
  331. ( (n_elem != x.n_elem) || ((x.n_rows != 1) && (x.n_cols != 1)) ),
  332. "spdiagview: given object has incompatible size"
  333. );
  334. const Mat<eT> tmp(x);
  335. (*this).operator=(tmp);
  336. }
  337. template<typename eT>
  338. template<typename T1>
  339. inline
  340. void
  341. spdiagview<eT>::operator+=(const SpBase<eT,T1>& o)
  342. {
  343. arma_extra_debug_sigprint();
  344. spdiagview<eT>& d = *this;
  345. SpMat<eT>& d_m = const_cast< SpMat<eT>& >(d.m);
  346. const uword d_n_elem = d.n_elem;
  347. const uword d_row_offset = d.row_offset;
  348. const uword d_col_offset = d.col_offset;
  349. const SpProxy<T1> P( o.get_ref() );
  350. arma_debug_check
  351. (
  352. ( (d_n_elem != P.get_n_elem()) || ((P.get_n_rows() != 1) && (P.get_n_cols() != 1)) ),
  353. "spdiagview: given object has incompatible size"
  354. );
  355. if( SpProxy<T1>::use_iterator || P.is_alias(d_m) )
  356. {
  357. const SpMat<eT> tmp(P.Q);
  358. if(tmp.n_cols == 1)
  359. {
  360. for(uword i=0; i < d_n_elem; ++i) { d_m.at(i + d_row_offset, i + d_col_offset) += tmp.at(i,0); }
  361. }
  362. else
  363. if(tmp.n_rows == 1)
  364. {
  365. for(uword i=0; i < d_n_elem; ++i) { d_m.at(i + d_row_offset, i + d_col_offset) += tmp.at(0,i); }
  366. }
  367. }
  368. else
  369. {
  370. if(P.get_n_cols() == 1)
  371. {
  372. for(uword i=0; i < d_n_elem; ++i) { d_m.at(i + d_row_offset, i + d_col_offset) += P.at(i,0); }
  373. }
  374. else
  375. if(P.get_n_rows() == 1)
  376. {
  377. for(uword i=0; i < d_n_elem; ++i) { d_m.at(i + d_row_offset, i + d_col_offset) += P.at(0,i); }
  378. }
  379. }
  380. }
  381. template<typename eT>
  382. template<typename T1>
  383. inline
  384. void
  385. spdiagview<eT>::operator-=(const SpBase<eT,T1>& o)
  386. {
  387. arma_extra_debug_sigprint();
  388. spdiagview<eT>& d = *this;
  389. SpMat<eT>& d_m = const_cast< SpMat<eT>& >(d.m);
  390. const uword d_n_elem = d.n_elem;
  391. const uword d_row_offset = d.row_offset;
  392. const uword d_col_offset = d.col_offset;
  393. const SpProxy<T1> P( o.get_ref() );
  394. arma_debug_check
  395. (
  396. ( (d_n_elem != P.get_n_elem()) || ((P.get_n_rows() != 1) && (P.get_n_cols() != 1)) ),
  397. "spdiagview: given object has incompatible size"
  398. );
  399. if( SpProxy<T1>::use_iterator || P.is_alias(d_m) )
  400. {
  401. const SpMat<eT> tmp(P.Q);
  402. if(tmp.n_cols == 1)
  403. {
  404. for(uword i=0; i < d_n_elem; ++i) { d_m.at(i + d_row_offset, i + d_col_offset) -= tmp.at(i,0); }
  405. }
  406. else
  407. if(tmp.n_rows == 1)
  408. {
  409. for(uword i=0; i < d_n_elem; ++i) { d_m.at(i + d_row_offset, i + d_col_offset) -= tmp.at(0,i); }
  410. }
  411. }
  412. else
  413. {
  414. if(P.get_n_cols() == 1)
  415. {
  416. for(uword i=0; i < d_n_elem; ++i) { d_m.at(i + d_row_offset, i + d_col_offset) -= P.at(i,0); }
  417. }
  418. else
  419. if(P.get_n_rows() == 1)
  420. {
  421. for(uword i=0; i < d_n_elem; ++i) { d_m.at(i + d_row_offset, i + d_col_offset) -= P.at(0,i); }
  422. }
  423. }
  424. }
  425. template<typename eT>
  426. template<typename T1>
  427. inline
  428. void
  429. spdiagview<eT>::operator%=(const SpBase<eT,T1>& o)
  430. {
  431. arma_extra_debug_sigprint();
  432. spdiagview<eT>& d = *this;
  433. SpMat<eT>& d_m = const_cast< SpMat<eT>& >(d.m);
  434. const uword d_n_elem = d.n_elem;
  435. const uword d_row_offset = d.row_offset;
  436. const uword d_col_offset = d.col_offset;
  437. const SpProxy<T1> P( o.get_ref() );
  438. arma_debug_check
  439. (
  440. ( (d_n_elem != P.get_n_elem()) || ((P.get_n_rows() != 1) && (P.get_n_cols() != 1)) ),
  441. "spdiagview: given object has incompatible size"
  442. );
  443. if( SpProxy<T1>::use_iterator || P.is_alias(d_m) )
  444. {
  445. const SpMat<eT> tmp(P.Q);
  446. if(tmp.n_cols == 1)
  447. {
  448. for(uword i=0; i < d_n_elem; ++i) { d_m.at(i + d_row_offset, i + d_col_offset) *= tmp.at(i,0); }
  449. }
  450. else
  451. if(tmp.n_rows == 1)
  452. {
  453. for(uword i=0; i < d_n_elem; ++i) { d_m.at(i + d_row_offset, i + d_col_offset) *= tmp.at(0,i); }
  454. }
  455. }
  456. else
  457. {
  458. if(P.get_n_cols() == 1)
  459. {
  460. for(uword i=0; i < d_n_elem; ++i) { d_m.at(i + d_row_offset, i + d_col_offset) *= P.at(i,0); }
  461. }
  462. else
  463. if(P.get_n_rows() == 1)
  464. {
  465. for(uword i=0; i < d_n_elem; ++i) { d_m.at(i + d_row_offset, i + d_col_offset) *= P.at(0,i); }
  466. }
  467. }
  468. }
  469. template<typename eT>
  470. template<typename T1>
  471. inline
  472. void
  473. spdiagview<eT>::operator/=(const SpBase<eT,T1>& o)
  474. {
  475. arma_extra_debug_sigprint();
  476. spdiagview<eT>& d = *this;
  477. SpMat<eT>& d_m = const_cast< SpMat<eT>& >(d.m);
  478. const uword d_n_elem = d.n_elem;
  479. const uword d_row_offset = d.row_offset;
  480. const uword d_col_offset = d.col_offset;
  481. const SpProxy<T1> P( o.get_ref() );
  482. arma_debug_check
  483. (
  484. ( (d_n_elem != P.get_n_elem()) || ((P.get_n_rows() != 1) && (P.get_n_cols() != 1)) ),
  485. "spdiagview: given object has incompatible size"
  486. );
  487. if( SpProxy<T1>::use_iterator || P.is_alias(d_m) )
  488. {
  489. const SpMat<eT> tmp(P.Q);
  490. if(tmp.n_cols == 1)
  491. {
  492. for(uword i=0; i < d_n_elem; ++i) { d_m.at(i + d_row_offset, i + d_col_offset) /= tmp.at(i,0); }
  493. }
  494. else
  495. if(tmp.n_rows == 1)
  496. {
  497. for(uword i=0; i < d_n_elem; ++i) { d_m.at(i + d_row_offset, i + d_col_offset) /= tmp.at(0,i); }
  498. }
  499. }
  500. else
  501. {
  502. if(P.get_n_cols() == 1)
  503. {
  504. for(uword i=0; i < d_n_elem; ++i) { d_m.at(i + d_row_offset, i + d_col_offset) /= P.at(i,0); }
  505. }
  506. else
  507. if(P.get_n_rows() == 1)
  508. {
  509. for(uword i=0; i < d_n_elem; ++i) { d_m.at(i + d_row_offset, i + d_col_offset) /= P.at(0,i); }
  510. }
  511. }
  512. }
  513. template<typename eT>
  514. inline
  515. void
  516. spdiagview<eT>::extract(SpMat<eT>& out, const spdiagview<eT>& d)
  517. {
  518. arma_extra_debug_sigprint();
  519. const SpMat<eT>& d_m = d.m;
  520. const uword d_n_elem = d.n_elem;
  521. const uword d_row_offset = d.row_offset;
  522. const uword d_col_offset = d.col_offset;
  523. Col<eT> cache(d_n_elem);
  524. eT* cache_mem = cache.memptr();
  525. uword d_n_nonzero = 0;
  526. for(uword i=0; i < d_n_elem; ++i)
  527. {
  528. const eT val = d_m.at(i + d_row_offset, i + d_col_offset);
  529. cache_mem[i] = val;
  530. d_n_nonzero += (val != eT(0)) ? uword(1) : uword(0);
  531. }
  532. out.reserve(d_n_elem, 1, d_n_nonzero);
  533. uword count = 0;
  534. for(uword i=0; i < d_n_elem; ++i)
  535. {
  536. const eT val = cache_mem[i];
  537. if(val != eT(0))
  538. {
  539. access::rw(out.row_indices[count]) = i;
  540. access::rw(out.values[count]) = val;
  541. ++count;
  542. }
  543. }
  544. access::rw(out.col_ptrs[0]) = 0;
  545. access::rw(out.col_ptrs[1]) = d_n_nonzero;
  546. }
  547. //! extract a diagonal and store it as a dense column vector
  548. template<typename eT>
  549. inline
  550. void
  551. spdiagview<eT>::extract(Mat<eT>& out, const spdiagview<eT>& in)
  552. {
  553. arma_extra_debug_sigprint();
  554. // NOTE: we're assuming that the 'out' matrix has already been set to the correct size;
  555. // size setting is done by either the Mat contructor or Mat::operator=()
  556. const SpMat<eT>& in_m = in.m;
  557. const uword in_n_elem = in.n_elem;
  558. const uword in_row_offset = in.row_offset;
  559. const uword in_col_offset = in.col_offset;
  560. eT* out_mem = out.memptr();
  561. for(uword i=0; i < in_n_elem; ++i)
  562. {
  563. out_mem[i] = in_m.at(i + in_row_offset, i + in_col_offset);
  564. }
  565. }
  566. template<typename eT>
  567. inline
  568. SpMat_MapMat_val<eT>
  569. spdiagview<eT>::operator[](const uword i)
  570. {
  571. return (const_cast< SpMat<eT>& >(m)).at(i+row_offset, i+col_offset);
  572. }
  573. template<typename eT>
  574. inline
  575. eT
  576. spdiagview<eT>::operator[](const uword i) const
  577. {
  578. return m.at(i+row_offset, i+col_offset);
  579. }
  580. template<typename eT>
  581. inline
  582. SpMat_MapMat_val<eT>
  583. spdiagview<eT>::at(const uword i)
  584. {
  585. return (const_cast< SpMat<eT>& >(m)).at(i+row_offset, i+col_offset);
  586. }
  587. template<typename eT>
  588. inline
  589. eT
  590. spdiagview<eT>::at(const uword i) const
  591. {
  592. return m.at(i+row_offset, i+col_offset);
  593. }
  594. template<typename eT>
  595. inline
  596. SpMat_MapMat_val<eT>
  597. spdiagview<eT>::operator()(const uword i)
  598. {
  599. arma_debug_check( (i >= n_elem), "spdiagview::operator(): out of bounds" );
  600. return (const_cast< SpMat<eT>& >(m)).at(i+row_offset, i+col_offset);
  601. }
  602. template<typename eT>
  603. inline
  604. eT
  605. spdiagview<eT>::operator()(const uword i) const
  606. {
  607. arma_debug_check( (i >= n_elem), "spdiagview::operator(): out of bounds" );
  608. return m.at(i+row_offset, i+col_offset);
  609. }
  610. template<typename eT>
  611. inline
  612. SpMat_MapMat_val<eT>
  613. spdiagview<eT>::at(const uword row, const uword)
  614. {
  615. return (const_cast< SpMat<eT>& >(m)).at(row+row_offset, row+col_offset);
  616. }
  617. template<typename eT>
  618. inline
  619. eT
  620. spdiagview<eT>::at(const uword row, const uword) const
  621. {
  622. return m.at(row+row_offset, row+col_offset);
  623. }
  624. template<typename eT>
  625. inline
  626. SpMat_MapMat_val<eT>
  627. spdiagview<eT>::operator()(const uword row, const uword col)
  628. {
  629. arma_debug_check( ((row >= n_elem) || (col > 0)), "spdiagview::operator(): out of bounds" );
  630. return (const_cast< SpMat<eT>& >(m)).at(row+row_offset, row+col_offset);
  631. }
  632. template<typename eT>
  633. inline
  634. eT
  635. spdiagview<eT>::operator()(const uword row, const uword col) const
  636. {
  637. arma_debug_check( ((row >= n_elem) || (col > 0)), "spdiagview::operator(): out of bounds" );
  638. return m.at(row+row_offset, row+col_offset);
  639. }
  640. template<typename eT>
  641. inline
  642. void
  643. spdiagview<eT>::fill(const eT val)
  644. {
  645. arma_extra_debug_sigprint();
  646. if( (row_offset == 0) && (col_offset == 0) && (m.sync_state != 1) )
  647. {
  648. if(val == eT(0))
  649. {
  650. SpMat<eT> tmp(arma_reserve_indicator(), m.n_rows, m.n_cols, m.n_nonzero); // worst case scenario
  651. typename SpMat<eT>::const_iterator it = m.begin();
  652. typename SpMat<eT>::const_iterator it_end = m.end();
  653. uword count = 0;
  654. for(; it != it_end; ++it)
  655. {
  656. const uword row = it.row();
  657. const uword col = it.col();
  658. if(row != col)
  659. {
  660. access::rw(tmp.values[count]) = (*it);
  661. access::rw(tmp.row_indices[count]) = row;
  662. access::rw(tmp.col_ptrs[col + 1])++;
  663. ++count;
  664. }
  665. }
  666. for(uword i=0; i < tmp.n_cols; ++i)
  667. {
  668. access::rw(tmp.col_ptrs[i + 1]) += tmp.col_ptrs[i];
  669. }
  670. // quick resize without reallocating memory and copying data
  671. access::rw( tmp.n_nonzero) = count;
  672. access::rw( tmp.values[count]) = eT(0);
  673. access::rw(tmp.row_indices[count]) = uword(0);
  674. access::rw(m).steal_mem(tmp);
  675. }
  676. else // val != eT(0)
  677. {
  678. SpMat<eT> tmp1;
  679. tmp1.eye(m.n_rows, m.n_cols);
  680. if(val != eT(1)) { tmp1 *= val; }
  681. SpMat<eT> tmp2;
  682. spglue_merge::diagview_merge(tmp2, m, tmp1);
  683. access::rw(m).steal_mem(tmp2);
  684. }
  685. }
  686. else
  687. {
  688. SpMat<eT>& x = const_cast< SpMat<eT>& >(m);
  689. const uword local_n_elem = n_elem;
  690. for(uword i=0; i < local_n_elem; ++i)
  691. {
  692. x.at(i+row_offset, i+col_offset) = val;
  693. }
  694. }
  695. }
  696. template<typename eT>
  697. inline
  698. void
  699. spdiagview<eT>::zeros()
  700. {
  701. arma_extra_debug_sigprint();
  702. (*this).fill(eT(0));
  703. }
  704. template<typename eT>
  705. inline
  706. void
  707. spdiagview<eT>::ones()
  708. {
  709. arma_extra_debug_sigprint();
  710. (*this).fill(eT(1));
  711. }
  712. template<typename eT>
  713. inline
  714. void
  715. spdiagview<eT>::randu()
  716. {
  717. arma_extra_debug_sigprint();
  718. SpMat<eT>& x = const_cast< SpMat<eT>& >(m);
  719. const uword local_n_elem = n_elem;
  720. for(uword i=0; i < local_n_elem; ++i)
  721. {
  722. x.at(i+row_offset, i+col_offset) = eT(arma_rng::randu<eT>());
  723. }
  724. }
  725. template<typename eT>
  726. inline
  727. void
  728. spdiagview<eT>::randn()
  729. {
  730. arma_extra_debug_sigprint();
  731. SpMat<eT>& x = const_cast< SpMat<eT>& >(m);
  732. const uword local_n_elem = n_elem;
  733. for(uword i=0; i < local_n_elem; ++i)
  734. {
  735. x.at(i+row_offset, i+col_offset) = eT(arma_rng::randn<eT>());
  736. }
  737. }
  738. //! @}