SpMat_meat.hpp 147 KB


  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup SpMat
  16. //! @{
  17. /**
  18. * Initialize a sparse matrix with size 0x0 (empty).
  19. */
  20. template<typename eT>
  21. inline
  22. SpMat<eT>::SpMat()
  23. : n_rows(0)
  24. , n_cols(0)
  25. , n_elem(0)
  26. , n_nonzero(0)
  27. , vec_state(0)
  28. , values(NULL)
  29. , row_indices(NULL)
  30. , col_ptrs(NULL)
  31. {
  32. arma_extra_debug_sigprint_this(this);
  33. init_cold(0,0);
  34. }
  35. /**
  36. * Clean up the memory of a sparse matrix and destruct it.
  37. */
  38. template<typename eT>
  39. inline
  40. SpMat<eT>::~SpMat()
  41. {
  42. arma_extra_debug_sigprint_this(this);
  43. if(values ) { memory::release(access::rw(values)); }
  44. if(row_indices) { memory::release(access::rw(row_indices)); }
  45. if(col_ptrs ) { memory::release(access::rw(col_ptrs)); }
  46. }
  47. /**
  48. * Constructor with size given.
  49. */
  50. template<typename eT>
  51. inline
  52. SpMat<eT>::SpMat(const uword in_rows, const uword in_cols)
  53. : n_rows(0)
  54. , n_cols(0)
  55. , n_elem(0)
  56. , n_nonzero(0)
  57. , vec_state(0)
  58. , values(NULL)
  59. , row_indices(NULL)
  60. , col_ptrs(NULL)
  61. {
  62. arma_extra_debug_sigprint_this(this);
  63. init_cold(in_rows, in_cols);
  64. }
  65. template<typename eT>
  66. inline
  67. SpMat<eT>::SpMat(const SizeMat& s)
  68. : n_rows(0)
  69. , n_cols(0)
  70. , n_elem(0)
  71. , n_nonzero(0)
  72. , vec_state(0)
  73. , values(NULL)
  74. , row_indices(NULL)
  75. , col_ptrs(NULL)
  76. {
  77. arma_extra_debug_sigprint_this(this);
  78. init_cold(s.n_rows, s.n_cols);
  79. }
  80. template<typename eT>
  81. inline
  82. SpMat<eT>::SpMat(const arma_reserve_indicator&, const uword in_rows, const uword in_cols, const uword new_n_nonzero)
  83. : n_rows(0)
  84. , n_cols(0)
  85. , n_elem(0)
  86. , n_nonzero(0)
  87. , vec_state(0)
  88. , values(NULL)
  89. , row_indices(NULL)
  90. , col_ptrs(NULL)
  91. {
  92. arma_extra_debug_sigprint_this(this);
  93. init_cold(in_rows, in_cols, new_n_nonzero);
  94. }
  95. template<typename eT>
  96. template<typename eT2>
  97. inline
  98. SpMat<eT>::SpMat(const arma_layout_indicator&, const SpMat<eT2>& x)
  99. : n_rows(0)
  100. , n_cols(0)
  101. , n_elem(0)
  102. , n_nonzero(0)
  103. , vec_state(0)
  104. , values(NULL)
  105. , row_indices(NULL)
  106. , col_ptrs(NULL)
  107. {
  108. arma_extra_debug_sigprint_this(this);
  109. init_cold(x.n_rows, x.n_cols, x.n_nonzero);
  110. if(x.n_nonzero == 0) { return; }
  111. if(x.row_indices) { arrayops::copy(access::rwp(row_indices), x.row_indices, x.n_nonzero + 1); }
  112. if(x.col_ptrs ) { arrayops::copy(access::rwp(col_ptrs), x.col_ptrs, x.n_cols + 1); }
  113. // NOTE: 'values' array is not initialised
  114. }
  115. /**
  116. * Assemble from text.
  117. */
  118. template<typename eT>
  119. inline
  120. SpMat<eT>::SpMat(const char* text)
  121. : n_rows(0)
  122. , n_cols(0)
  123. , n_elem(0)
  124. , n_nonzero(0)
  125. , vec_state(0)
  126. , values(NULL)
  127. , row_indices(NULL)
  128. , col_ptrs(NULL)
  129. {
  130. arma_extra_debug_sigprint_this(this);
  131. init(std::string(text));
  132. }
  133. template<typename eT>
  134. inline
  135. SpMat<eT>&
  136. SpMat<eT>::operator=(const char* text)
  137. {
  138. arma_extra_debug_sigprint();
  139. init(std::string(text));
  140. return *this;
  141. }
  142. template<typename eT>
  143. inline
  144. SpMat<eT>::SpMat(const std::string& text)
  145. : n_rows(0)
  146. , n_cols(0)
  147. , n_elem(0)
  148. , n_nonzero(0)
  149. , vec_state(0)
  150. , values(NULL)
  151. , row_indices(NULL)
  152. , col_ptrs(NULL)
  153. {
  154. arma_extra_debug_sigprint();
  155. init(text);
  156. }
  157. template<typename eT>
  158. inline
  159. SpMat<eT>&
  160. SpMat<eT>::operator=(const std::string& text)
  161. {
  162. arma_extra_debug_sigprint();
  163. init(text);
  164. return *this;
  165. }
  166. template<typename eT>
  167. inline
  168. SpMat<eT>::SpMat(const SpMat<eT>& x)
  169. : n_rows(0)
  170. , n_cols(0)
  171. , n_elem(0)
  172. , n_nonzero(0)
  173. , vec_state(0)
  174. , values(NULL)
  175. , row_indices(NULL)
  176. , col_ptrs(NULL)
  177. {
  178. arma_extra_debug_sigprint_this(this);
  179. init(x);
  180. }
  #if defined(ARMA_USE_CXX11)

  //! Move constructor (only when ARMA_USE_CXX11 is defined).
  //! Starts from an empty state, then takes over in_mat's storage via steal_mem().
  template<typename eT>
  inline
  SpMat<eT>::SpMat(SpMat<eT>&& in_mat)
    : n_rows(0)
    , n_cols(0)
    , n_elem(0)
    , n_nonzero(0)
    , vec_state(0)
    , values(NULL)
    , row_indices(NULL)
    , col_ptrs(NULL)
  {
    arma_extra_debug_sigprint_this(this);
    arma_extra_debug_sigprint(arma_str::format("this = %x in_mat = %x") % this % &in_mat);

    steal_mem(in_mat);
  }

  //! Move assignment (only when ARMA_USE_CXX11 is defined):
  //! takes over in_mat's storage via steal_mem().
  template<typename eT>
  inline
  SpMat<eT>&
  SpMat<eT>::operator=(SpMat<eT>&& in_mat)
  {
    arma_extra_debug_sigprint(arma_str::format("this = %x in_mat = %x") % this % &in_mat);

    steal_mem(in_mat);

    return *this;
  }

  #endif
  208. template<typename eT>
  209. inline
  210. SpMat<eT>::SpMat(const MapMat<eT>& x)
  211. : n_rows(0)
  212. , n_cols(0)
  213. , n_elem(0)
  214. , n_nonzero(0)
  215. , vec_state(0)
  216. , values(NULL)
  217. , row_indices(NULL)
  218. , col_ptrs(NULL)
  219. {
  220. arma_extra_debug_sigprint_this(this);
  221. init(x);
  222. }
  223. template<typename eT>
  224. inline
  225. SpMat<eT>&
  226. SpMat<eT>::operator=(const MapMat<eT>& x)
  227. {
  228. arma_extra_debug_sigprint();
  229. init(x);
  230. return *this;
  231. }
  232. //! Insert a large number of values at once.
  233. //! locations.row[0] should be row indices, locations.row[1] should be column indices,
  234. //! and values should be the corresponding values.
  235. //! If sort_locations is false, then it is assumed that the locations and values
  236. //! are already sorted in column-major ordering.
  237. template<typename eT>
  238. template<typename T1, typename T2>
  239. inline
  240. SpMat<eT>::SpMat(const Base<uword,T1>& locations_expr, const Base<eT,T2>& vals_expr, const bool sort_locations)
  241. : n_rows(0)
  242. , n_cols(0)
  243. , n_elem(0)
  244. , n_nonzero(0)
  245. , vec_state(0)
  246. , values(NULL)
  247. , row_indices(NULL)
  248. , col_ptrs(NULL)
  249. {
  250. arma_extra_debug_sigprint_this(this);
  251. const unwrap<T1> locs_tmp( locations_expr.get_ref() );
  252. const unwrap<T2> vals_tmp( vals_expr.get_ref() );
  253. const Mat<uword>& locs = locs_tmp.M;
  254. const Mat<eT>& vals = vals_tmp.M;
  255. arma_debug_check( (vals.is_vec() == false), "SpMat::SpMat(): given 'values' object is not a vector" );
  256. arma_debug_check( (locs.n_rows != 2), "SpMat::SpMat(): locations matrix must have two rows" );
  257. arma_debug_check( (locs.n_cols != vals.n_elem), "SpMat::SpMat(): number of locations is different than number of values" );
  258. // If there are no elements in the list, max() will fail.
  259. if(locs.n_cols == 0) { init_cold(0, 0); return; }
  260. // Automatically determine size before pruning zeros.
  261. uvec bounds = arma::max(locs, 1);
  262. init_cold(bounds[0] + 1, bounds[1] + 1);
  263. // Ensure that there are no zeros
  264. const uword N_old = vals.n_elem;
  265. uword N_new = 0;
  266. for(uword i = 0; i < N_old; ++i)
  267. {
  268. if(vals[i] != eT(0)) { ++N_new; }
  269. }
  270. if(N_new != N_old)
  271. {
  272. Col<eT> filtered_vals(N_new);
  273. Mat<uword> filtered_locs(2, N_new);
  274. uword index = 0;
  275. for(uword i = 0; i < N_old; ++i)
  276. {
  277. if(vals[i] != eT(0))
  278. {
  279. filtered_vals[index] = vals[i];
  280. filtered_locs.at(0, index) = locs.at(0, i);
  281. filtered_locs.at(1, index) = locs.at(1, i);
  282. ++index;
  283. }
  284. }
  285. init_batch_std(filtered_locs, filtered_vals, sort_locations);
  286. }
  287. else
  288. {
  289. init_batch_std(locs, vals, sort_locations);
  290. }
  291. }
  292. //! Insert a large number of values at once.
  293. //! locations.row[0] should be row indices, locations.row[1] should be column indices,
  294. //! and values should be the corresponding values.
  295. //! If sort_locations is false, then it is assumed that the locations and values
  296. //! are already sorted in column-major ordering.
  297. //! In this constructor the size is explicitly given.
  298. template<typename eT>
  299. template<typename T1, typename T2>
  300. inline
  301. SpMat<eT>::SpMat(const Base<uword,T1>& locations_expr, const Base<eT,T2>& vals_expr, const uword in_n_rows, const uword in_n_cols, const bool sort_locations, const bool check_for_zeros)
  302. : n_rows(0)
  303. , n_cols(0)
  304. , n_elem(0)
  305. , n_nonzero(0)
  306. , vec_state(0)
  307. , values(NULL)
  308. , row_indices(NULL)
  309. , col_ptrs(NULL)
  310. {
  311. arma_extra_debug_sigprint_this(this);
  312. const unwrap<T1> locs_tmp( locations_expr.get_ref() );
  313. const unwrap<T2> vals_tmp( vals_expr.get_ref() );
  314. const Mat<uword>& locs = locs_tmp.M;
  315. const Mat<eT>& vals = vals_tmp.M;
  316. arma_debug_check( (vals.is_vec() == false), "SpMat::SpMat(): given 'values' object is not a vector" );
  317. arma_debug_check( (locs.n_rows != 2), "SpMat::SpMat(): locations matrix must have two rows" );
  318. arma_debug_check( (locs.n_cols != vals.n_elem), "SpMat::SpMat(): number of locations is different than number of values" );
  319. init_cold(in_n_rows, in_n_cols);
  320. // Ensure that there are no zeros, unless the user asked not to.
  321. if(check_for_zeros)
  322. {
  323. const uword N_old = vals.n_elem;
  324. uword N_new = 0;
  325. for(uword i = 0; i < N_old; ++i)
  326. {
  327. if(vals[i] != eT(0)) { ++N_new; }
  328. }
  329. if(N_new != N_old)
  330. {
  331. Col<eT> filtered_vals(N_new);
  332. Mat<uword> filtered_locs(2, N_new);
  333. uword index = 0;
  334. for(uword i = 0; i < N_old; ++i)
  335. {
  336. if(vals[i] != eT(0))
  337. {
  338. filtered_vals[index] = vals[i];
  339. filtered_locs.at(0, index) = locs.at(0, i);
  340. filtered_locs.at(1, index) = locs.at(1, i);
  341. ++index;
  342. }
  343. }
  344. init_batch_std(filtered_locs, filtered_vals, sort_locations);
  345. }
  346. else
  347. {
  348. init_batch_std(locs, vals, sort_locations);
  349. }
  350. }
  351. else
  352. {
  353. init_batch_std(locs, vals, sort_locations);
  354. }
  355. }
  356. template<typename eT>
  357. template<typename T1, typename T2>
  358. inline
  359. SpMat<eT>::SpMat(const bool add_values, const Base<uword,T1>& locations_expr, const Base<eT,T2>& vals_expr, const uword in_n_rows, const uword in_n_cols, const bool sort_locations, const bool check_for_zeros)
  360. : n_rows(0)
  361. , n_cols(0)
  362. , n_elem(0)
  363. , n_nonzero(0)
  364. , vec_state(0)
  365. , values(NULL)
  366. , row_indices(NULL)
  367. , col_ptrs(NULL)
  368. {
  369. arma_extra_debug_sigprint_this(this);
  370. const unwrap<T1> locs_tmp( locations_expr.get_ref() );
  371. const unwrap<T2> vals_tmp( vals_expr.get_ref() );
  372. const Mat<uword>& locs = locs_tmp.M;
  373. const Mat<eT>& vals = vals_tmp.M;
  374. arma_debug_check( (vals.is_vec() == false), "SpMat::SpMat(): given 'values' object is not a vector" );
  375. arma_debug_check( (locs.n_rows != 2), "SpMat::SpMat(): locations matrix must have two rows" );
  376. arma_debug_check( (locs.n_cols != vals.n_elem), "SpMat::SpMat(): number of locations is different than number of values" );
  377. init_cold(in_n_rows, in_n_cols);
  378. // Ensure that there are no zeros, unless the user asked not to.
  379. if(check_for_zeros)
  380. {
  381. const uword N_old = vals.n_elem;
  382. uword N_new = 0;
  383. for(uword i = 0; i < N_old; ++i)
  384. {
  385. if(vals[i] != eT(0)) { ++N_new; }
  386. }
  387. if(N_new != N_old)
  388. {
  389. Col<eT> filtered_vals(N_new);
  390. Mat<uword> filtered_locs(2, N_new);
  391. uword index = 0;
  392. for(uword i = 0; i < N_old; ++i)
  393. {
  394. if(vals[i] != eT(0))
  395. {
  396. filtered_vals[index] = vals[i];
  397. filtered_locs.at(0, index) = locs.at(0, i);
  398. filtered_locs.at(1, index) = locs.at(1, i);
  399. ++index;
  400. }
  401. }
  402. add_values ? init_batch_add(filtered_locs, filtered_vals, sort_locations) : init_batch_std(filtered_locs, filtered_vals, sort_locations);
  403. }
  404. else
  405. {
  406. add_values ? init_batch_add(locs, vals, sort_locations) : init_batch_std(locs, vals, sort_locations);
  407. }
  408. }
  409. else
  410. {
  411. add_values ? init_batch_add(locs, vals, sort_locations) : init_batch_std(locs, vals, sort_locations);
  412. }
  413. }
  414. //! Insert a large number of values at once.
  415. //! Per CSC format, rowind_expr should be row indices,
  416. //! colptr_expr should column ptr indices locations,
  417. //! and values should be the corresponding values.
  418. //! In this constructor the size is explicitly given.
  419. //! Values are assumed to be sorted, and the size
  420. //! information is trusted
  421. template<typename eT>
  422. template<typename T1, typename T2, typename T3>
  423. inline
  424. SpMat<eT>::SpMat
  425. (
  426. const Base<uword,T1>& rowind_expr,
  427. const Base<uword,T2>& colptr_expr,
  428. const Base<eT, T3>& values_expr,
  429. const uword in_n_rows,
  430. const uword in_n_cols
  431. )
  432. : n_rows(0)
  433. , n_cols(0)
  434. , n_elem(0)
  435. , n_nonzero(0)
  436. , vec_state(0)
  437. , values(NULL)
  438. , row_indices(NULL)
  439. , col_ptrs(NULL)
  440. {
  441. arma_extra_debug_sigprint_this(this);
  442. const unwrap<T1> rowind_tmp( rowind_expr.get_ref() );
  443. const unwrap<T2> colptr_tmp( colptr_expr.get_ref() );
  444. const unwrap<T3> vals_tmp( values_expr.get_ref() );
  445. const Mat<uword>& rowind = rowind_tmp.M;
  446. const Mat<uword>& colptr = colptr_tmp.M;
  447. const Mat<eT>& vals = vals_tmp.M;
  448. arma_debug_check( (rowind.is_vec() == false), "SpMat::SpMat(): given 'rowind' object is not a vector" );
  449. arma_debug_check( (colptr.is_vec() == false), "SpMat::SpMat(): given 'colptr' object is not a vector" );
  450. arma_debug_check( (vals.is_vec() == false), "SpMat::SpMat(): given 'values' object is not a vector" );
  451. // Resize to correct number of elements (this also sets n_nonzero)
  452. init_cold(in_n_rows, in_n_cols, vals.n_elem);
  453. arma_debug_check( (rowind.n_elem != vals.n_elem), "SpMat::SpMat(): number of row indices is not equal to number of values" );
  454. arma_debug_check( (colptr.n_elem != (n_cols+1) ), "SpMat::SpMat(): number of column pointers is not equal to n_cols+1" );
  455. // copy supplied values into sparse matrix -- not checked for consistency
  456. arrayops::copy(access::rwp(row_indices), rowind.memptr(), rowind.n_elem );
  457. arrayops::copy(access::rwp(col_ptrs), colptr.memptr(), colptr.n_elem );
  458. arrayops::copy(access::rwp(values), vals.memptr(), vals.n_elem );
  459. // important: set the sentinel as well
  460. access::rw(col_ptrs[n_cols + 1]) = std::numeric_limits<uword>::max();
  461. // make sure no zeros are stored
  462. remove_zeros();
  463. }
  464. template<typename eT>
  465. inline
  466. SpMat<eT>&
  467. SpMat<eT>::operator=(const eT val)
  468. {
  469. arma_extra_debug_sigprint();
  470. if(val != eT(0))
  471. {
  472. // Resize to 1x1 then set that to the right value.
  473. init(1, 1, 1); // Sets col_ptrs to 0.
  474. // Manually set element.
  475. access::rw(values[0]) = val;
  476. access::rw(row_indices[0]) = 0;
  477. access::rw(col_ptrs[1]) = 1;
  478. }
  479. else
  480. {
  481. init(0, 0);
  482. }
  483. return *this;
  484. }
  485. template<typename eT>
  486. inline
  487. SpMat<eT>&
  488. SpMat<eT>::operator*=(const eT val)
  489. {
  490. arma_extra_debug_sigprint();
  491. if(val != eT(0))
  492. {
  493. sync_csc();
  494. invalidate_cache();
  495. const uword n_nz = n_nonzero;
  496. eT* vals = access::rwp(values);
  497. bool has_zero = false;
  498. for(uword i=0; i<n_nz; ++i)
  499. {
  500. eT& vals_i = vals[i];
  501. vals_i *= val;
  502. if(vals_i == eT(0)) { has_zero = true; }
  503. }
  504. if(has_zero) { remove_zeros(); }
  505. }
  506. else
  507. {
  508. (*this).zeros();
  509. }
  510. return *this;
  511. }
  512. template<typename eT>
  513. inline
  514. SpMat<eT>&
  515. SpMat<eT>::operator/=(const eT val)
  516. {
  517. arma_extra_debug_sigprint();
  518. arma_debug_check( (val == eT(0)), "element-wise division: division by zero" );
  519. sync_csc();
  520. invalidate_cache();
  521. const uword n_nz = n_nonzero;
  522. eT* vals = access::rwp(values);
  523. bool has_zero = false;
  524. for(uword i=0; i<n_nz; ++i)
  525. {
  526. eT& vals_i = vals[i];
  527. vals_i /= val;
  528. if(vals_i == eT(0)) { has_zero = true; }
  529. }
  530. if(has_zero) { remove_zeros(); }
  531. return *this;
  532. }
  533. template<typename eT>
  534. inline
  535. SpMat<eT>&
  536. SpMat<eT>::operator=(const SpMat<eT>& x)
  537. {
  538. arma_extra_debug_sigprint();
  539. init(x);
  540. return *this;
  541. }
  542. template<typename eT>
  543. inline
  544. SpMat<eT>&
  545. SpMat<eT>::operator+=(const SpMat<eT>& x)
  546. {
  547. arma_extra_debug_sigprint();
  548. sync_csc();
  549. SpMat<eT> out = (*this) + x;
  550. steal_mem(out);
  551. return *this;
  552. }
  553. template<typename eT>
  554. inline
  555. SpMat<eT>&
  556. SpMat<eT>::operator-=(const SpMat<eT>& x)
  557. {
  558. arma_extra_debug_sigprint();
  559. sync_csc();
  560. SpMat<eT> out = (*this) - x;
  561. steal_mem(out);
  562. return *this;
  563. }
  564. template<typename eT>
  565. inline
  566. SpMat<eT>&
  567. SpMat<eT>::operator*=(const SpMat<eT>& y)
  568. {
  569. arma_extra_debug_sigprint();
  570. sync_csc();
  571. SpMat<eT> z = (*this) * y;
  572. steal_mem(z);
  573. return *this;
  574. }
  575. // This is in-place element-wise matrix multiplication.
  576. template<typename eT>
  577. inline
  578. SpMat<eT>&
  579. SpMat<eT>::operator%=(const SpMat<eT>& y)
  580. {
  581. arma_extra_debug_sigprint();
  582. sync_csc();
  583. SpMat<eT> z = (*this) % y;
  584. steal_mem(z);
  585. return *this;
  586. }
  587. template<typename eT>
  588. inline
  589. SpMat<eT>&
  590. SpMat<eT>::operator/=(const SpMat<eT>& x)
  591. {
  592. arma_extra_debug_sigprint();
  593. // NOTE: use of this function is not advised; it is implemented only for completeness
  594. arma_debug_assert_same_size(n_rows, n_cols, x.n_rows, x.n_cols, "element-wise division");
  595. for(uword c = 0; c < n_cols; ++c)
  596. for(uword r = 0; r < n_rows; ++r)
  597. {
  598. at(r, c) /= x.at(r, c);
  599. }
  600. return *this;
  601. }
  602. template<typename eT>
  603. template<typename T1, typename op_type>
  604. inline
  605. SpMat<eT>::SpMat(const SpToDOp<T1, op_type>& expr)
  606. : n_rows(0)
  607. , n_cols(0)
  608. , n_elem(0)
  609. , n_nonzero(0)
  610. , vec_state(0)
  611. , values(NULL)
  612. , row_indices(NULL)
  613. , col_ptrs(NULL)
  614. {
  615. arma_extra_debug_sigprint_this(this);
  616. typedef typename T1::elem_type T;
  617. // Make sure the type is compatible.
  618. arma_type_check(( is_same_type< eT, T >::no ));
  619. op_type::apply(*this, expr);
  620. }
  621. // Construct a complex matrix out of two non-complex matrices
  622. template<typename eT>
  623. template<typename T1, typename T2>
  624. inline
  625. SpMat<eT>::SpMat
  626. (
  627. const SpBase<typename SpMat<eT>::pod_type, T1>& A,
  628. const SpBase<typename SpMat<eT>::pod_type, T2>& B
  629. )
  630. : n_rows(0)
  631. , n_cols(0)
  632. , n_elem(0)
  633. , n_nonzero(0)
  634. , vec_state(0)
  635. , values(NULL)
  636. , row_indices(NULL)
  637. , col_ptrs(NULL)
  638. {
  639. arma_extra_debug_sigprint();
  640. typedef typename T1::elem_type T;
  641. // Make sure eT is complex and T is not (compile-time check).
  642. arma_type_check(( is_cx<eT>::no ));
  643. arma_type_check(( is_cx< T>::yes ));
  644. // Compile-time abort if types are not compatible.
  645. arma_type_check(( is_same_type< std::complex<T>, eT >::no ));
  646. const unwrap_spmat<T1> tmp1(A.get_ref());
  647. const unwrap_spmat<T2> tmp2(B.get_ref());
  648. const SpMat<T>& X = tmp1.M;
  649. const SpMat<T>& Y = tmp2.M;
  650. arma_debug_assert_same_size(X.n_rows, X.n_cols, Y.n_rows, Y.n_cols, "SpMat()");
  651. const uword l_n_rows = X.n_rows;
  652. const uword l_n_cols = X.n_cols;
  653. // Set size of matrix correctly.
  654. init_cold(l_n_rows, l_n_cols, n_unique(X, Y, op_n_unique_count()));
  655. // Now on a second iteration, fill it.
  656. typename SpMat<T>::const_iterator x_it = X.begin();
  657. typename SpMat<T>::const_iterator x_end = X.end();
  658. typename SpMat<T>::const_iterator y_it = Y.begin();
  659. typename SpMat<T>::const_iterator y_end = Y.end();
  660. uword cur_pos = 0;
  661. while((x_it != x_end) || (y_it != y_end))
  662. {
  663. if(x_it == y_it) // if we are at the same place
  664. {
  665. access::rw(values[cur_pos]) = std::complex<T>((T) *x_it, (T) *y_it);
  666. access::rw(row_indices[cur_pos]) = x_it.row();
  667. ++access::rw(col_ptrs[x_it.col() + 1]);
  668. ++x_it;
  669. ++y_it;
  670. }
  671. else
  672. {
  673. if((x_it.col() < y_it.col()) || ((x_it.col() == y_it.col()) && (x_it.row() < y_it.row()))) // if y is closer to the end
  674. {
  675. access::rw(values[cur_pos]) = std::complex<T>((T) *x_it, T(0));
  676. access::rw(row_indices[cur_pos]) = x_it.row();
  677. ++access::rw(col_ptrs[x_it.col() + 1]);
  678. ++x_it;
  679. }
  680. else // x is closer to the end
  681. {
  682. access::rw(values[cur_pos]) = std::complex<T>(T(0), (T) *y_it);
  683. access::rw(row_indices[cur_pos]) = y_it.row();
  684. ++access::rw(col_ptrs[y_it.col() + 1]);
  685. ++y_it;
  686. }
  687. }
  688. ++cur_pos;
  689. }
  690. // Now fix the column pointers; they are supposed to be a sum.
  691. for(uword c = 1; c <= n_cols; ++c)
  692. {
  693. access::rw(col_ptrs[c]) += col_ptrs[c - 1];
  694. }
  695. }
  696. template<typename eT>
  697. template<typename T1>
  698. inline
  699. SpMat<eT>::SpMat(const Base<eT, T1>& x)
  700. : n_rows(0)
  701. , n_cols(0)
  702. , n_elem(0)
  703. , n_nonzero(0)
  704. , vec_state(0)
  705. , values(NULL)
  706. , row_indices(NULL)
  707. , col_ptrs(NULL)
  708. {
  709. arma_extra_debug_sigprint_this(this);
  710. (*this).operator=(x);
  711. }
  712. template<typename eT>
  713. template<typename T1>
  714. inline
  715. SpMat<eT>&
  716. SpMat<eT>::operator=(const Base<eT, T1>& expr)
  717. {
  718. arma_extra_debug_sigprint();
  719. if(is_same_type< T1, Gen<Mat<eT>, gen_zeros> >::yes)
  720. {
  721. const Proxy<T1> P(expr.get_ref());
  722. (*this).zeros( P.get_n_rows(), P.get_n_cols() );
  723. return *this;
  724. }
  725. if(is_same_type< T1, Gen<Mat<eT>, gen_eye> >::yes)
  726. {
  727. const Proxy<T1> P(expr.get_ref());
  728. (*this).eye( P.get_n_rows(), P.get_n_cols() );
  729. return *this;
  730. }
  731. const quasi_unwrap<T1> tmp(expr.get_ref());
  732. const Mat<eT>& x = tmp.M;
  733. const uword x_n_rows = x.n_rows;
  734. const uword x_n_cols = x.n_cols;
  735. const uword x_n_elem = x.n_elem;
  736. // Count number of nonzero elements in base object.
  737. uword n = 0;
  738. const eT* x_mem = x.memptr();
  739. for(uword i = 0; i < x_n_elem; ++i)
  740. {
  741. n += (x_mem[i] != eT(0)) ? uword(1) : uword(0);
  742. }
  743. init(x_n_rows, x_n_cols, n);
  744. if(n == 0) { return *this; }
  745. // Now the memory is resized correctly; set nonzero elements.
  746. n = 0;
  747. for(uword j = 0; j < x_n_cols; ++j)
  748. for(uword i = 0; i < x_n_rows; ++i)
  749. {
  750. const eT val = (*x_mem); x_mem++;
  751. if(val != eT(0))
  752. {
  753. access::rw(values[n]) = val;
  754. access::rw(row_indices[n]) = i;
  755. access::rw(col_ptrs[j + 1])++;
  756. ++n;
  757. }
  758. }
  759. // Sum column counts to be column pointers.
  760. for(uword c = 1; c <= n_cols; ++c)
  761. {
  762. access::rw(col_ptrs[c]) += col_ptrs[c - 1];
  763. }
  764. return *this;
  765. }
  766. template<typename eT>
  767. template<typename T1>
  768. inline
  769. SpMat<eT>&
  770. SpMat<eT>::operator+=(const Base<eT, T1>& x)
  771. {
  772. arma_extra_debug_sigprint();
  773. sync_csc();
  774. return (*this).operator=( (*this) + x.get_ref() );
  775. }
  776. template<typename eT>
  777. template<typename T1>
  778. inline
  779. SpMat<eT>&
  780. SpMat<eT>::operator-=(const Base<eT, T1>& x)
  781. {
  782. arma_extra_debug_sigprint();
  783. sync_csc();
  784. return (*this).operator=( (*this) - x.get_ref() );
  785. }
  786. template<typename eT>
  787. template<typename T1>
  788. inline
  789. SpMat<eT>&
  790. SpMat<eT>::operator*=(const Base<eT, T1>& y)
  791. {
  792. arma_extra_debug_sigprint();
  793. sync_csc();
  794. const Proxy<T1> p(y.get_ref());
  795. arma_debug_assert_mul_size(n_rows, n_cols, p.get_n_rows(), p.get_n_cols(), "matrix multiplication");
  796. // We assume the matrix structure is such that we will end up with a sparse
  797. // matrix. Assuming that every entry in the dense matrix is nonzero (which is
  798. // a fairly valid assumption), each row with any nonzero elements in it (in this
  799. // matrix) implies an entire nonzero column. Therefore, we iterate over all
  800. // the row_indices and count the number of rows with any elements in them
  801. // (using the quasi-linked-list idea from SYMBMM -- see spglue_times_meat.hpp).
  802. podarray<uword> index(n_rows);
  803. index.fill(n_rows); // Fill with invalid links.
  804. uword last_index = n_rows + 1;
  805. for(uword i = 0; i < n_nonzero; ++i)
  806. {
  807. if(index[row_indices[i]] == n_rows)
  808. {
  809. index[row_indices[i]] = last_index;
  810. last_index = row_indices[i];
  811. }
  812. }
  813. // Now count the number of rows which have nonzero elements.
  814. uword nonzero_rows = 0;
  815. while(last_index != n_rows + 1)
  816. {
  817. ++nonzero_rows;
  818. last_index = index[last_index];
  819. }
  820. SpMat<eT> z(arma_reserve_indicator(), n_rows, p.get_n_cols(), (nonzero_rows * p.get_n_cols())); // upper bound on size
  821. // Now we have to fill all the elements using a modification of the NUMBMM algorithm.
  822. uword cur_pos = 0;
  823. podarray<eT> partial_sums(n_rows);
  824. partial_sums.zeros();
  825. for(uword lcol = 0; lcol < n_cols; ++lcol)
  826. {
  827. const_iterator it = begin();
  828. const_iterator it_end = end();
  829. while(it != it_end)
  830. {
  831. const eT value = (*it);
  832. partial_sums[it.row()] += (value * p.at(it.col(), lcol));
  833. ++it;
  834. }
  835. // Now add all partial sums to the matrix.
  836. for(uword i = 0; i < n_rows; ++i)
  837. {
  838. if(partial_sums[i] != eT(0))
  839. {
  840. access::rw(z.values[cur_pos]) = partial_sums[i];
  841. access::rw(z.row_indices[cur_pos]) = i;
  842. ++access::rw(z.col_ptrs[lcol + 1]);
  843. //printf("colptr %d now %d\n", lcol + 1, z.col_ptrs[lcol + 1]);
  844. ++cur_pos;
  845. partial_sums[i] = 0; // Would it be faster to do this in batch later?
  846. }
  847. }
  848. }
  849. // Now fix the column pointers.
  850. for(uword c = 1; c <= z.n_cols; ++c)
  851. {
  852. access::rw(z.col_ptrs[c]) += z.col_ptrs[c - 1];
  853. }
  854. // Resize to final correct size.
  855. z.mem_resize(z.col_ptrs[z.n_cols]);
  856. // Now take the memory of the temporary matrix.
  857. steal_mem(z);
  858. return *this;
  859. }
  860. /**
  861. * Don't use this function. It's not mathematically well-defined and wastes
  862. * cycles to trash all your data. This is dumb.
  863. */
  864. template<typename eT>
  865. template<typename T1>
  866. inline
  867. SpMat<eT>&
  868. SpMat<eT>::operator/=(const Base<eT, T1>& x)
  869. {
  870. arma_extra_debug_sigprint();
  871. sync_csc();
  872. SpMat<eT> tmp = (*this) / x.get_ref();
  873. steal_mem(tmp);
  874. return *this;
  875. }
  876. template<typename eT>
  877. template<typename T1>
  878. inline
  879. SpMat<eT>&
  880. SpMat<eT>::operator%=(const Base<eT, T1>& x)
  881. {
  882. arma_extra_debug_sigprint();
  883. sync_csc();
  884. const Proxy<T1> p(x.get_ref());
  885. arma_debug_assert_same_size(n_rows, n_cols, p.get_n_rows(), p.get_n_cols(), "element-wise multiplication");
  886. // Count the number of elements we will need.
  887. const_iterator it = begin();
  888. const_iterator it_end = end();
  889. uword new_n_nonzero = 0;
  890. while(it != it_end)
  891. {
  892. // use_at == false can't save us any work here
  893. if(((*it) * p.at(it.row(), it.col())) != eT(0))
  894. {
  895. ++new_n_nonzero;
  896. }
  897. ++it;
  898. }
  899. SpMat<eT> tmp(arma_reserve_indicator(), n_rows, n_cols, new_n_nonzero);
  900. const_iterator c_it = begin();
  901. const_iterator c_it_end = end();
  902. uword cur_pos = 0;
  903. while(c_it != c_it_end)
  904. {
  905. // use_at == false can't save us any work here
  906. const eT val = (*c_it) * p.at(c_it.row(), c_it.col());
  907. if(val != eT(0))
  908. {
  909. access::rw(tmp.values[cur_pos]) = val;
  910. access::rw(tmp.row_indices[cur_pos]) = c_it.row();
  911. ++access::rw(tmp.col_ptrs[c_it.col() + 1]);
  912. ++cur_pos;
  913. }
  914. ++c_it;
  915. }
  916. // Fix column pointers.
  917. for(uword c = 1; c <= n_cols; ++c)
  918. {
  919. access::rw(tmp.col_ptrs[c]) += tmp.col_ptrs[c - 1];
  920. }
  921. steal_mem(tmp);
  922. return *this;
  923. }
  924. template<typename eT>
  925. template<typename T1>
  926. inline
  927. SpMat<eT>::SpMat(const Op<T1, op_diagmat>& expr)
  928. : n_rows(0)
  929. , n_cols(0)
  930. , n_elem(0)
  931. , n_nonzero(0)
  932. , vec_state(0)
  933. , values(NULL)
  934. , row_indices(NULL)
  935. , col_ptrs(NULL)
  936. {
  937. arma_extra_debug_sigprint_this(this);
  938. (*this).operator=(expr);
  939. }
  940. template<typename eT>
  941. template<typename T1>
  942. inline
  943. SpMat<eT>&
  944. SpMat<eT>::operator=(const Op<T1, op_diagmat>& expr)
  945. {
  946. arma_extra_debug_sigprint();
  947. const diagmat_proxy<T1> P(expr.m);
  948. const uword max_n_nonzero = (std::min)(P.n_rows, P.n_cols);
  949. // resize memory to upper bound
  950. init(P.n_rows, P.n_cols, max_n_nonzero);
  951. uword count = 0;
  952. for(uword i=0; i < max_n_nonzero; ++i)
  953. {
  954. const eT val = P[i];
  955. if(val != eT(0))
  956. {
  957. access::rw(values[count]) = val;
  958. access::rw(row_indices[count]) = i;
  959. access::rw(col_ptrs[i + 1])++;
  960. ++count;
  961. }
  962. }
  963. // fix column pointers to be cumulative
  964. for(uword i = 1; i < n_cols + 1; ++i)
  965. {
  966. access::rw(col_ptrs[i]) += col_ptrs[i - 1];
  967. }
  968. // quick resize without reallocating memory and copying data
  969. access::rw( n_nonzero) = count;
  970. access::rw( values[count]) = eT(0);
  971. access::rw(row_indices[count]) = uword(0);
  972. return *this;
  973. }
  974. template<typename eT>
  975. template<typename T1>
  976. inline
  977. SpMat<eT>&
  978. SpMat<eT>::operator+=(const Op<T1, op_diagmat>& expr)
  979. {
  980. arma_extra_debug_sigprint();
  981. const SpMat<eT> tmp(expr);
  982. return (*this).operator+=(tmp);
  983. }
  984. template<typename eT>
  985. template<typename T1>
  986. inline
  987. SpMat<eT>&
  988. SpMat<eT>::operator-=(const Op<T1, op_diagmat>& expr)
  989. {
  990. arma_extra_debug_sigprint();
  991. const SpMat<eT> tmp(expr);
  992. return (*this).operator-=(tmp);
  993. }
  994. template<typename eT>
  995. template<typename T1>
  996. inline
  997. SpMat<eT>&
  998. SpMat<eT>::operator*=(const Op<T1, op_diagmat>& expr)
  999. {
  1000. arma_extra_debug_sigprint();
  1001. const SpMat<eT> tmp(expr);
  1002. return (*this).operator*=(tmp);
  1003. }
  1004. template<typename eT>
  1005. template<typename T1>
  1006. inline
  1007. SpMat<eT>&
  1008. SpMat<eT>::operator/=(const Op<T1, op_diagmat>& expr)
  1009. {
  1010. arma_extra_debug_sigprint();
  1011. const SpMat<eT> tmp(expr);
  1012. return (*this).operator/=(tmp);
  1013. }
  1014. template<typename eT>
  1015. template<typename T1>
  1016. inline
  1017. SpMat<eT>&
  1018. SpMat<eT>::operator%=(const Op<T1, op_diagmat>& expr)
  1019. {
  1020. arma_extra_debug_sigprint();
  1021. const SpMat<eT> tmp(expr);
  1022. return (*this).operator%=(tmp);
  1023. }
/**
 * Construction, assignment and compound assignment from sparse subviews (SpSubview).
 */
  1027. template<typename eT>
  1028. inline
  1029. SpMat<eT>::SpMat(const SpSubview<eT>& X)
  1030. : n_rows(0)
  1031. , n_cols(0)
  1032. , n_elem(0)
  1033. , n_nonzero(0)
  1034. , vec_state(0)
  1035. , values(NULL)
  1036. , row_indices(NULL)
  1037. , col_ptrs(NULL)
  1038. {
  1039. arma_extra_debug_sigprint_this(this);
  1040. (*this).operator=(X);
  1041. }
//! Assign the contents of a sparse subview.
//! An empty subview (X.n_nonzero == 0) short-circuits to zeros(X.n_rows, X.n_cols).
//! If the subview refers to this very matrix, the copy is built in a separate temporary
//! whose memory is then taken over via steal_mem().
template<typename eT>
inline
SpMat<eT>&
SpMat<eT>::operator=(const SpSubview<eT>& X)
{
  arma_extra_debug_sigprint();
  
  if(X.n_nonzero == 0) { zeros(X.n_rows, X.n_cols); return *this; }
  
  // presumably brings the parent's CSC arrays up to date before they are iterated below
  X.m.sync_csc();
  
  const bool alias = (this == &(X.m));
  
  if(alias)
  {
    // filling *this in place would overwrite the source; build the copy elsewhere first
    SpMat<eT> tmp(X);
    steal_mem(tmp);
  }
  else
  {
    // room for exactly X.n_nonzero entries; col_ptrs is used below as per-column counters
    // (relies on init() leaving col_ptrs zeroed -- NOTE(review): confirm)
    init(X.n_rows, X.n_cols, X.n_nonzero);
    
    if(X.n_rows == X.m.n_rows)
    {
      // Fast path: the subview covers every row of the parent, so its entries are exactly the
      // parent's entries in columns [aux_col1, aux_col1 + n_cols - 1], visited in column order.
      // Row indices are copied unchanged; only the column index is shifted.
      const uword sv_col_start = X.aux_col1;
      const uword sv_col_end = X.aux_col1 + X.n_cols - 1;
      
      typename SpMat<eT>::const_col_iterator m_it = X.m.begin_col(sv_col_start);
      typename SpMat<eT>::const_col_iterator m_it_end = X.m.end_col(sv_col_end);
      
      uword count = 0;
      
      while(m_it != m_it_end)
      {
        const uword m_it_col_adjusted = m_it.col() - sv_col_start;
        
        access::rw(row_indices[count]) = m_it.row();
        access::rw(values[count]) = (*m_it);
        ++access::rw(col_ptrs[m_it_col_adjusted + 1]);
        
        count++;
        ++m_it;
      }
    }
    else
    {
      // General path: walk the subview's own iterator. it.pos() is used as the output slot,
      // presumably the running index of the entry within the subview; it.row()/it.col()
      // are presumably subview-relative coordinates -- NOTE(review): confirm in SpSubview_iterators.
      typename SpSubview<eT>::const_iterator it = X.begin();
      typename SpSubview<eT>::const_iterator it_end = X.end();
      
      while(it != it_end)
      {
        const uword it_pos = it.pos();
        
        access::rw(row_indices[it_pos]) = it.row();
        access::rw(values[it_pos]) = (*it);
        ++access::rw(col_ptrs[it.col() + 1]);
        
        ++it;
      }
    }
    
    // Now sum column pointers (turn per-column counts into cumulative offsets).
    for(uword c = 1; c <= n_cols; ++c)
    {
      access::rw(col_ptrs[c]) += col_ptrs[c - 1];
    }
  }
  
  return *this;
}
  1097. template<typename eT>
  1098. inline
  1099. SpMat<eT>&
  1100. SpMat<eT>::operator+=(const SpSubview<eT>& X)
  1101. {
  1102. arma_extra_debug_sigprint();
  1103. sync_csc();
  1104. SpMat<eT> tmp = (*this) + X;
  1105. steal_mem(tmp);
  1106. return *this;
  1107. }
  1108. template<typename eT>
  1109. inline
  1110. SpMat<eT>&
  1111. SpMat<eT>::operator-=(const SpSubview<eT>& X)
  1112. {
  1113. arma_extra_debug_sigprint();
  1114. sync_csc();
  1115. SpMat<eT> tmp = (*this) - X;
  1116. steal_mem(tmp);
  1117. return *this;
  1118. }
  1119. template<typename eT>
  1120. inline
  1121. SpMat<eT>&
  1122. SpMat<eT>::operator*=(const SpSubview<eT>& y)
  1123. {
  1124. arma_extra_debug_sigprint();
  1125. sync_csc();
  1126. SpMat<eT> z = (*this) * y;
  1127. steal_mem(z);
  1128. return *this;
  1129. }
  1130. template<typename eT>
  1131. inline
  1132. SpMat<eT>&
  1133. SpMat<eT>::operator%=(const SpSubview<eT>& x)
  1134. {
  1135. arma_extra_debug_sigprint();
  1136. sync_csc();
  1137. SpMat<eT> tmp = (*this) % x;
  1138. steal_mem(tmp);
  1139. return *this;
  1140. }
  1141. template<typename eT>
  1142. inline
  1143. SpMat<eT>&
  1144. SpMat<eT>::operator/=(const SpSubview<eT>& x)
  1145. {
  1146. arma_extra_debug_sigprint();
  1147. arma_debug_assert_same_size(n_rows, n_cols, x.n_rows, x.n_cols, "element-wise division");
  1148. // There is no pretty way to do this.
  1149. for(uword elem = 0; elem < n_elem; elem++)
  1150. {
  1151. at(elem) /= x(elem);
  1152. }
  1153. return *this;
  1154. }
  1155. template<typename eT>
  1156. inline
  1157. SpMat<eT>::SpMat(const spdiagview<eT>& X)
  1158. : n_rows(0)
  1159. , n_cols(0)
  1160. , n_elem(0)
  1161. , n_nonzero(0)
  1162. , vec_state(0)
  1163. , values(NULL)
  1164. , row_indices(NULL)
  1165. , col_ptrs(NULL)
  1166. {
  1167. arma_extra_debug_sigprint_this(this);
  1168. spdiagview<eT>::extract(*this, X);
  1169. }
  1170. template<typename eT>
  1171. inline
  1172. SpMat<eT>&
  1173. SpMat<eT>::operator=(const spdiagview<eT>& X)
  1174. {
  1175. arma_extra_debug_sigprint();
  1176. spdiagview<eT>::extract(*this, X);
  1177. return *this;
  1178. }
  1179. template<typename eT>
  1180. inline
  1181. SpMat<eT>&
  1182. SpMat<eT>::operator+=(const spdiagview<eT>& X)
  1183. {
  1184. arma_extra_debug_sigprint();
  1185. const SpMat<eT> tmp(X);
  1186. return (*this).operator+=(tmp);
  1187. }
  1188. template<typename eT>
  1189. inline
  1190. SpMat<eT>&
  1191. SpMat<eT>::operator-=(const spdiagview<eT>& X)
  1192. {
  1193. arma_extra_debug_sigprint();
  1194. const SpMat<eT> tmp(X);
  1195. return (*this).operator-=(tmp);
  1196. }
  1197. template<typename eT>
  1198. inline
  1199. SpMat<eT>&
  1200. SpMat<eT>::operator*=(const spdiagview<eT>& X)
  1201. {
  1202. arma_extra_debug_sigprint();
  1203. const SpMat<eT> tmp(X);
  1204. return (*this).operator*=(tmp);
  1205. }
  1206. template<typename eT>
  1207. inline
  1208. SpMat<eT>&
  1209. SpMat<eT>::operator%=(const spdiagview<eT>& X)
  1210. {
  1211. arma_extra_debug_sigprint();
  1212. const SpMat<eT> tmp(X);
  1213. return (*this).operator%=(tmp);
  1214. }
  1215. template<typename eT>
  1216. inline
  1217. SpMat<eT>&
  1218. SpMat<eT>::operator/=(const spdiagview<eT>& X)
  1219. {
  1220. arma_extra_debug_sigprint();
  1221. const SpMat<eT> tmp(X);
  1222. return (*this).operator/=(tmp);
  1223. }
  1224. template<typename eT>
  1225. template<typename T1, typename spop_type>
  1226. inline
  1227. SpMat<eT>::SpMat(const SpOp<T1, spop_type>& X)
  1228. : n_rows(0)
  1229. , n_cols(0)
  1230. , n_elem(0)
  1231. , n_nonzero(0)
  1232. , vec_state(0)
  1233. , values(NULL) // set in application of sparse operation
  1234. , row_indices(NULL)
  1235. , col_ptrs(NULL)
  1236. {
  1237. arma_extra_debug_sigprint_this(this);
  1238. arma_type_check(( is_same_type< eT, typename T1::elem_type >::no ));
  1239. spop_type::apply(*this, X);
  1240. sync_csc(); // in case apply() used element accessors
  1241. invalidate_cache(); // in case apply() modified the CSC representation
  1242. }
  1243. template<typename eT>
  1244. template<typename T1, typename spop_type>
  1245. inline
  1246. SpMat<eT>&
  1247. SpMat<eT>::operator=(const SpOp<T1, spop_type>& X)
  1248. {
  1249. arma_extra_debug_sigprint();
  1250. arma_type_check(( is_same_type< eT, typename T1::elem_type >::no ));
  1251. spop_type::apply(*this, X);
  1252. sync_csc(); // in case apply() used element accessors
  1253. invalidate_cache(); // in case apply() modified the CSC representation
  1254. return *this;
  1255. }
  1256. template<typename eT>
  1257. template<typename T1, typename spop_type>
  1258. inline
  1259. SpMat<eT>&
  1260. SpMat<eT>::operator+=(const SpOp<T1, spop_type>& X)
  1261. {
  1262. arma_extra_debug_sigprint();
  1263. arma_type_check(( is_same_type< eT, typename T1::elem_type >::no ));
  1264. sync_csc();
  1265. const SpMat<eT> m(X);
  1266. return (*this).operator+=(m);
  1267. }
  1268. template<typename eT>
  1269. template<typename T1, typename spop_type>
  1270. inline
  1271. SpMat<eT>&
  1272. SpMat<eT>::operator-=(const SpOp<T1, spop_type>& X)
  1273. {
  1274. arma_extra_debug_sigprint();
  1275. arma_type_check(( is_same_type< eT, typename T1::elem_type >::no ));
  1276. sync_csc();
  1277. const SpMat<eT> m(X);
  1278. return (*this).operator-=(m);
  1279. }
  1280. template<typename eT>
  1281. template<typename T1, typename spop_type>
  1282. inline
  1283. SpMat<eT>&
  1284. SpMat<eT>::operator*=(const SpOp<T1, spop_type>& X)
  1285. {
  1286. arma_extra_debug_sigprint();
  1287. arma_type_check(( is_same_type< eT, typename T1::elem_type >::no ));
  1288. sync_csc();
  1289. const SpMat<eT> m(X);
  1290. return (*this).operator*=(m);
  1291. }
  1292. template<typename eT>
  1293. template<typename T1, typename spop_type>
  1294. inline
  1295. SpMat<eT>&
  1296. SpMat<eT>::operator%=(const SpOp<T1, spop_type>& X)
  1297. {
  1298. arma_extra_debug_sigprint();
  1299. arma_type_check(( is_same_type< eT, typename T1::elem_type >::no ));
  1300. sync_csc();
  1301. const SpMat<eT> m(X);
  1302. return (*this).operator%=(m);
  1303. }
  1304. template<typename eT>
  1305. template<typename T1, typename spop_type>
  1306. inline
  1307. SpMat<eT>&
  1308. SpMat<eT>::operator/=(const SpOp<T1, spop_type>& X)
  1309. {
  1310. arma_extra_debug_sigprint();
  1311. arma_type_check(( is_same_type< eT, typename T1::elem_type >::no ));
  1312. sync_csc();
  1313. const SpMat<eT> m(X);
  1314. return (*this).operator/=(m);
  1315. }
  1316. template<typename eT>
  1317. template<typename T1, typename T2, typename spglue_type>
  1318. inline
  1319. SpMat<eT>::SpMat(const SpGlue<T1, T2, spglue_type>& X)
  1320. : n_rows(0)
  1321. , n_cols(0)
  1322. , n_elem(0)
  1323. , n_nonzero(0)
  1324. , vec_state(0)
  1325. , values(NULL)
  1326. , row_indices(NULL)
  1327. , col_ptrs(NULL)
  1328. {
  1329. arma_extra_debug_sigprint_this(this);
  1330. arma_type_check(( is_same_type< eT, typename T1::elem_type >::no ));
  1331. spglue_type::apply(*this, X);
  1332. sync_csc(); // in case apply() used element accessors
  1333. invalidate_cache(); // in case apply() modified the CSC representation
  1334. }
  1335. template<typename eT>
  1336. template<typename T1, typename T2, typename spglue_type>
  1337. inline
  1338. SpMat<eT>&
  1339. SpMat<eT>::operator=(const SpGlue<T1, T2, spglue_type>& X)
  1340. {
  1341. arma_extra_debug_sigprint();
  1342. arma_type_check(( is_same_type< eT, typename T1::elem_type >::no ));
  1343. spglue_type::apply(*this, X);
  1344. sync_csc(); // in case apply() used element accessors
  1345. invalidate_cache(); // in case apply() modified the CSC representation
  1346. return *this;
  1347. }
  1348. template<typename eT>
  1349. template<typename T1, typename T2, typename spglue_type>
  1350. inline
  1351. SpMat<eT>&
  1352. SpMat<eT>::operator+=(const SpGlue<T1, T2, spglue_type>& X)
  1353. {
  1354. arma_extra_debug_sigprint();
  1355. arma_type_check(( is_same_type< eT, typename T1::elem_type >::no ));
  1356. sync_csc();
  1357. const SpMat<eT> m(X);
  1358. return (*this).operator+=(m);
  1359. }
  1360. template<typename eT>
  1361. template<typename T1, typename T2, typename spglue_type>
  1362. inline
  1363. SpMat<eT>&
  1364. SpMat<eT>::operator-=(const SpGlue<T1, T2, spglue_type>& X)
  1365. {
  1366. arma_extra_debug_sigprint();
  1367. arma_type_check(( is_same_type< eT, typename T1::elem_type >::no ));
  1368. sync_csc();
  1369. const SpMat<eT> m(X);
  1370. return (*this).operator-=(m);
  1371. }
  1372. template<typename eT>
  1373. template<typename T1, typename T2, typename spglue_type>
  1374. inline
  1375. SpMat<eT>&
  1376. SpMat<eT>::operator*=(const SpGlue<T1, T2, spglue_type>& X)
  1377. {
  1378. arma_extra_debug_sigprint();
  1379. arma_type_check(( is_same_type< eT, typename T1::elem_type >::no ));
  1380. sync_csc();
  1381. const SpMat<eT> m(X);
  1382. return (*this).operator*=(m);
  1383. }
  1384. template<typename eT>
  1385. template<typename T1, typename T2, typename spglue_type>
  1386. inline
  1387. SpMat<eT>&
  1388. SpMat<eT>::operator%=(const SpGlue<T1, T2, spglue_type>& X)
  1389. {
  1390. arma_extra_debug_sigprint();
  1391. arma_type_check(( is_same_type< eT, typename T1::elem_type >::no ));
  1392. sync_csc();
  1393. const SpMat<eT> m(X);
  1394. return (*this).operator%=(m);
  1395. }
  1396. template<typename eT>
  1397. template<typename T1, typename T2, typename spglue_type>
  1398. inline
  1399. SpMat<eT>&
  1400. SpMat<eT>::operator/=(const SpGlue<T1, T2, spglue_type>& X)
  1401. {
  1402. arma_extra_debug_sigprint();
  1403. arma_type_check(( is_same_type< eT, typename T1::elem_type >::no ));
  1404. sync_csc();
  1405. const SpMat<eT> m(X);
  1406. return (*this).operator/=(m);
  1407. }
  1408. template<typename eT>
  1409. template<typename T1, typename spop_type>
  1410. inline
  1411. SpMat<eT>::SpMat(const mtSpOp<eT, T1, spop_type>& X)
  1412. : n_rows(0)
  1413. , n_cols(0)
  1414. , n_elem(0)
  1415. , n_nonzero(0)
  1416. , vec_state(0)
  1417. , values(NULL)
  1418. , row_indices(NULL)
  1419. , col_ptrs(NULL)
  1420. {
  1421. arma_extra_debug_sigprint_this(this);
  1422. spop_type::apply(*this, X);
  1423. sync_csc(); // in case apply() used element accessors
  1424. invalidate_cache(); // in case apply() modified the CSC representation
  1425. }
  1426. template<typename eT>
  1427. template<typename T1, typename spop_type>
  1428. inline
  1429. SpMat<eT>&
  1430. SpMat<eT>::operator=(const mtSpOp<eT, T1, spop_type>& X)
  1431. {
  1432. arma_extra_debug_sigprint();
  1433. spop_type::apply(*this, X);
  1434. sync_csc(); // in case apply() used element accessors
  1435. invalidate_cache(); // in case apply() modified the CSC representation
  1436. return *this;
  1437. }
  1438. template<typename eT>
  1439. template<typename T1, typename spop_type>
  1440. inline
  1441. SpMat<eT>&
  1442. SpMat<eT>::operator+=(const mtSpOp<eT, T1, spop_type>& X)
  1443. {
  1444. arma_extra_debug_sigprint();
  1445. sync_csc();
  1446. const SpMat<eT> m(X);
  1447. return (*this).operator+=(m);
  1448. }
  1449. template<typename eT>
  1450. template<typename T1, typename spop_type>
  1451. inline
  1452. SpMat<eT>&
  1453. SpMat<eT>::operator-=(const mtSpOp<eT, T1, spop_type>& X)
  1454. {
  1455. arma_extra_debug_sigprint();
  1456. sync_csc();
  1457. const SpMat<eT> m(X);
  1458. return (*this).operator-=(m);
  1459. }
  1460. template<typename eT>
  1461. template<typename T1, typename spop_type>
  1462. inline
  1463. SpMat<eT>&
  1464. SpMat<eT>::operator*=(const mtSpOp<eT, T1, spop_type>& X)
  1465. {
  1466. arma_extra_debug_sigprint();
  1467. sync_csc();
  1468. const SpMat<eT> m(X);
  1469. return (*this).operator*=(m);
  1470. }
  1471. template<typename eT>
  1472. template<typename T1, typename spop_type>
  1473. inline
  1474. SpMat<eT>&
  1475. SpMat<eT>::operator%=(const mtSpOp<eT, T1, spop_type>& X)
  1476. {
  1477. arma_extra_debug_sigprint();
  1478. sync_csc();
  1479. const SpMat<eT> m(X);
  1480. return (*this).operator%=(m);
  1481. }
  1482. template<typename eT>
  1483. template<typename T1, typename spop_type>
  1484. inline
  1485. SpMat<eT>&
  1486. SpMat<eT>::operator/=(const mtSpOp<eT, T1, spop_type>& X)
  1487. {
  1488. arma_extra_debug_sigprint();
  1489. sync_csc();
  1490. const SpMat<eT> m(X);
  1491. return (*this).operator/=(m);
  1492. }
  1493. template<typename eT>
  1494. template<typename T1, typename T2, typename spglue_type>
  1495. inline
  1496. SpMat<eT>::SpMat(const mtSpGlue<eT, T1, T2, spglue_type>& X)
  1497. : n_rows(0)
  1498. , n_cols(0)
  1499. , n_elem(0)
  1500. , n_nonzero(0)
  1501. , vec_state(0)
  1502. , values(NULL)
  1503. , row_indices(NULL)
  1504. , col_ptrs(NULL)
  1505. {
  1506. arma_extra_debug_sigprint_this(this);
  1507. spglue_type::apply(*this, X);
  1508. sync_csc(); // in case apply() used element accessors
  1509. invalidate_cache(); // in case apply() modified the CSC representation
  1510. }
  1511. template<typename eT>
  1512. template<typename T1, typename T2, typename spglue_type>
  1513. inline
  1514. SpMat<eT>&
  1515. SpMat<eT>::operator=(const mtSpGlue<eT, T1, T2, spglue_type>& X)
  1516. {
  1517. arma_extra_debug_sigprint();
  1518. spglue_type::apply(*this, X);
  1519. sync_csc(); // in case apply() used element accessors
  1520. invalidate_cache(); // in case apply() modified the CSC representation
  1521. return *this;
  1522. }
  1523. template<typename eT>
  1524. template<typename T1, typename T2, typename spglue_type>
  1525. inline
  1526. SpMat<eT>&
  1527. SpMat<eT>::operator+=(const mtSpGlue<eT, T1, T2, spglue_type>& X)
  1528. {
  1529. arma_extra_debug_sigprint();
  1530. sync_csc();
  1531. const SpMat<eT> m(X);
  1532. return (*this).operator+=(m);
  1533. }
  1534. template<typename eT>
  1535. template<typename T1, typename T2, typename spglue_type>
  1536. inline
  1537. SpMat<eT>&
  1538. SpMat<eT>::operator-=(const mtSpGlue<eT, T1, T2, spglue_type>& X)
  1539. {
  1540. arma_extra_debug_sigprint();
  1541. sync_csc();
  1542. const SpMat<eT> m(X);
  1543. return (*this).operator-=(m);
  1544. }
  1545. template<typename eT>
  1546. template<typename T1, typename T2, typename spglue_type>
  1547. inline
  1548. SpMat<eT>&
  1549. SpMat<eT>::operator*=(const mtSpGlue<eT, T1, T2, spglue_type>& X)
  1550. {
  1551. arma_extra_debug_sigprint();
  1552. sync_csc();
  1553. const SpMat<eT> m(X);
  1554. return (*this).operator*=(m);
  1555. }
  1556. template<typename eT>
  1557. template<typename T1, typename T2, typename spglue_type>
  1558. inline
  1559. SpMat<eT>&
  1560. SpMat<eT>::operator%=(const mtSpGlue<eT, T1, T2, spglue_type>& X)
  1561. {
  1562. arma_extra_debug_sigprint();
  1563. sync_csc();
  1564. const SpMat<eT> m(X);
  1565. return (*this).operator%=(m);
  1566. }
  1567. template<typename eT>
  1568. template<typename T1, typename T2, typename spglue_type>
  1569. inline
  1570. SpMat<eT>&
  1571. SpMat<eT>::operator/=(const mtSpGlue<eT, T1, T2, spglue_type>& X)
  1572. {
  1573. arma_extra_debug_sigprint();
  1574. sync_csc();
  1575. const SpMat<eT> m(X);
  1576. return (*this).operator/=(m);
  1577. }
  1578. template<typename eT>
  1579. arma_inline
  1580. SpSubview_row<eT>
  1581. SpMat<eT>::row(const uword row_num)
  1582. {
  1583. arma_extra_debug_sigprint();
  1584. arma_debug_check(row_num >= n_rows, "SpMat::row(): out of bounds");
  1585. return SpSubview_row<eT>(*this, row_num);
  1586. }
  1587. template<typename eT>
  1588. arma_inline
  1589. const SpSubview_row<eT>
  1590. SpMat<eT>::row(const uword row_num) const
  1591. {
  1592. arma_extra_debug_sigprint();
  1593. arma_debug_check(row_num >= n_rows, "SpMat::row(): out of bounds");
  1594. return SpSubview_row<eT>(*this, row_num);
  1595. }
  1596. template<typename eT>
  1597. inline
  1598. SpSubview_row<eT>
  1599. SpMat<eT>::operator()(const uword row_num, const span& col_span)
  1600. {
  1601. arma_extra_debug_sigprint();
  1602. const bool col_all = col_span.whole;
  1603. const uword local_n_cols = n_cols;
  1604. const uword in_col1 = col_all ? 0 : col_span.a;
  1605. const uword in_col2 = col_span.b;
  1606. const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1;
  1607. arma_debug_check
  1608. (
  1609. (row_num >= n_rows)
  1610. ||
  1611. ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) )
  1612. ,
  1613. "SpMat::operator(): indices out of bounds or incorrectly used"
  1614. );
  1615. return SpSubview_row<eT>(*this, row_num, in_col1, submat_n_cols);
  1616. }
  1617. template<typename eT>
  1618. inline
  1619. const SpSubview_row<eT>
  1620. SpMat<eT>::operator()(const uword row_num, const span& col_span) const
  1621. {
  1622. arma_extra_debug_sigprint();
  1623. const bool col_all = col_span.whole;
  1624. const uword local_n_cols = n_cols;
  1625. const uword in_col1 = col_all ? 0 : col_span.a;
  1626. const uword in_col2 = col_span.b;
  1627. const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1;
  1628. arma_debug_check
  1629. (
  1630. (row_num >= n_rows)
  1631. ||
  1632. ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) )
  1633. ,
  1634. "SpMat::operator(): indices out of bounds or incorrectly used"
  1635. );
  1636. return SpSubview_row<eT>(*this, row_num, in_col1, submat_n_cols);
  1637. }
  1638. template<typename eT>
  1639. arma_inline
  1640. SpSubview_col<eT>
  1641. SpMat<eT>::col(const uword col_num)
  1642. {
  1643. arma_extra_debug_sigprint();
  1644. arma_debug_check(col_num >= n_cols, "SpMat::col(): out of bounds");
  1645. return SpSubview_col<eT>(*this, col_num);
  1646. }
  1647. template<typename eT>
  1648. arma_inline
  1649. const SpSubview_col<eT>
  1650. SpMat<eT>::col(const uword col_num) const
  1651. {
  1652. arma_extra_debug_sigprint();
  1653. arma_debug_check(col_num >= n_cols, "SpMat::col(): out of bounds");
  1654. return SpSubview_col<eT>(*this, col_num);
  1655. }
  1656. template<typename eT>
  1657. inline
  1658. SpSubview_col<eT>
  1659. SpMat<eT>::operator()(const span& row_span, const uword col_num)
  1660. {
  1661. arma_extra_debug_sigprint();
  1662. const bool row_all = row_span.whole;
  1663. const uword local_n_rows = n_rows;
  1664. const uword in_row1 = row_all ? 0 : row_span.a;
  1665. const uword in_row2 = row_span.b;
  1666. const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1;
  1667. arma_debug_check
  1668. (
  1669. (col_num >= n_cols)
  1670. ||
  1671. ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) )
  1672. ,
  1673. "SpMat::operator(): indices out of bounds or incorrectly used"
  1674. );
  1675. return SpSubview_col<eT>(*this, col_num, in_row1, submat_n_rows);
  1676. }
  1677. template<typename eT>
  1678. inline
  1679. const SpSubview_col<eT>
  1680. SpMat<eT>::operator()(const span& row_span, const uword col_num) const
  1681. {
  1682. arma_extra_debug_sigprint();
  1683. const bool row_all = row_span.whole;
  1684. const uword local_n_rows = n_rows;
  1685. const uword in_row1 = row_all ? 0 : row_span.a;
  1686. const uword in_row2 = row_span.b;
  1687. const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1;
  1688. arma_debug_check
  1689. (
  1690. (col_num >= n_cols)
  1691. ||
  1692. ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) )
  1693. ,
  1694. "SpMat::operator(): indices out of bounds or incorrectly used"
  1695. );
  1696. return SpSubview_col<eT>(*this, col_num, in_row1, submat_n_rows);
  1697. }
  1698. template<typename eT>
  1699. arma_inline
  1700. SpSubview<eT>
  1701. SpMat<eT>::rows(const uword in_row1, const uword in_row2)
  1702. {
  1703. arma_extra_debug_sigprint();
  1704. arma_debug_check
  1705. (
  1706. (in_row1 > in_row2) || (in_row2 >= n_rows),
  1707. "SpMat::rows(): indices out of bounds or incorrectly used"
  1708. );
  1709. const uword subview_n_rows = in_row2 - in_row1 + 1;
  1710. return SpSubview<eT>(*this, in_row1, 0, subview_n_rows, n_cols);
  1711. }
  1712. template<typename eT>
  1713. arma_inline
  1714. const SpSubview<eT>
  1715. SpMat<eT>::rows(const uword in_row1, const uword in_row2) const
  1716. {
  1717. arma_extra_debug_sigprint();
  1718. arma_debug_check
  1719. (
  1720. (in_row1 > in_row2) || (in_row2 >= n_rows),
  1721. "SpMat::rows(): indices out of bounds or incorrectly used"
  1722. );
  1723. const uword subview_n_rows = in_row2 - in_row1 + 1;
  1724. return SpSubview<eT>(*this, in_row1, 0, subview_n_rows, n_cols);
  1725. }
  1726. template<typename eT>
  1727. arma_inline
  1728. SpSubview<eT>
  1729. SpMat<eT>::cols(const uword in_col1, const uword in_col2)
  1730. {
  1731. arma_extra_debug_sigprint();
  1732. arma_debug_check
  1733. (
  1734. (in_col1 > in_col2) || (in_col2 >= n_cols),
  1735. "SpMat::cols(): indices out of bounds or incorrectly used"
  1736. );
  1737. const uword subview_n_cols = in_col2 - in_col1 + 1;
  1738. return SpSubview<eT>(*this, 0, in_col1, n_rows, subview_n_cols);
  1739. }
  1740. template<typename eT>
  1741. arma_inline
  1742. const SpSubview<eT>
  1743. SpMat<eT>::cols(const uword in_col1, const uword in_col2) const
  1744. {
  1745. arma_extra_debug_sigprint();
  1746. arma_debug_check
  1747. (
  1748. (in_col1 > in_col2) || (in_col2 >= n_cols),
  1749. "SpMat::cols(): indices out of bounds or incorrectly used"
  1750. );
  1751. const uword subview_n_cols = in_col2 - in_col1 + 1;
  1752. return SpSubview<eT>(*this, 0, in_col1, n_rows, subview_n_cols);
  1753. }
  1754. template<typename eT>
  1755. arma_inline
  1756. SpSubview<eT>
  1757. SpMat<eT>::submat(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2)
  1758. {
  1759. arma_extra_debug_sigprint();
  1760. arma_debug_check
  1761. (
  1762. (in_row1 > in_row2) || (in_col1 > in_col2) || (in_row2 >= n_rows) || (in_col2 >= n_cols),
  1763. "SpMat::submat(): indices out of bounds or incorrectly used"
  1764. );
  1765. const uword subview_n_rows = in_row2 - in_row1 + 1;
  1766. const uword subview_n_cols = in_col2 - in_col1 + 1;
  1767. return SpSubview<eT>(*this, in_row1, in_col1, subview_n_rows, subview_n_cols);
  1768. }
  1769. template<typename eT>
  1770. arma_inline
  1771. const SpSubview<eT>
  1772. SpMat<eT>::submat(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2) const
  1773. {
  1774. arma_extra_debug_sigprint();
  1775. arma_debug_check
  1776. (
  1777. (in_row1 > in_row2) || (in_col1 > in_col2) || (in_row2 >= n_rows) || (in_col2 >= n_cols),
  1778. "SpMat::submat(): indices out of bounds or incorrectly used"
  1779. );
  1780. const uword subview_n_rows = in_row2 - in_row1 + 1;
  1781. const uword subview_n_cols = in_col2 - in_col1 + 1;
  1782. return SpSubview<eT>(*this, in_row1, in_col1, subview_n_rows, subview_n_cols);
  1783. }
  1784. template<typename eT>
  1785. arma_inline
  1786. SpSubview<eT>
  1787. SpMat<eT>::submat(const uword in_row1, const uword in_col1, const SizeMat& s)
  1788. {
  1789. arma_extra_debug_sigprint();
  1790. const uword l_n_rows = n_rows;
  1791. const uword l_n_cols = n_cols;
  1792. const uword s_n_rows = s.n_rows;
  1793. const uword s_n_cols = s.n_cols;
  1794. arma_debug_check
  1795. (
  1796. ((in_row1 >= l_n_rows) || (in_col1 >= l_n_cols) || ((in_row1 + s_n_rows) > l_n_rows) || ((in_col1 + s_n_cols) > l_n_cols)),
  1797. "SpMat::submat(): indices or size out of bounds"
  1798. );
  1799. return SpSubview<eT>(*this, in_row1, in_col1, s_n_rows, s_n_cols);
  1800. }
  1801. template<typename eT>
  1802. arma_inline
  1803. const SpSubview<eT>
  1804. SpMat<eT>::submat(const uword in_row1, const uword in_col1, const SizeMat& s) const
  1805. {
  1806. arma_extra_debug_sigprint();
  1807. const uword l_n_rows = n_rows;
  1808. const uword l_n_cols = n_cols;
  1809. const uword s_n_rows = s.n_rows;
  1810. const uword s_n_cols = s.n_cols;
  1811. arma_debug_check
  1812. (
  1813. ((in_row1 >= l_n_rows) || (in_col1 >= l_n_cols) || ((in_row1 + s_n_rows) > l_n_rows) || ((in_col1 + s_n_cols) > l_n_cols)),
  1814. "SpMat::submat(): indices or size out of bounds"
  1815. );
  1816. return SpSubview<eT>(*this, in_row1, in_col1, s_n_rows, s_n_cols);
  1817. }
  1818. template<typename eT>
  1819. inline
  1820. SpSubview<eT>
  1821. SpMat<eT>::submat(const span& row_span, const span& col_span)
  1822. {
  1823. arma_extra_debug_sigprint();
  1824. const bool row_all = row_span.whole;
  1825. const bool col_all = col_span.whole;
  1826. const uword local_n_rows = n_rows;
  1827. const uword local_n_cols = n_cols;
  1828. const uword in_row1 = row_all ? 0 : row_span.a;
  1829. const uword in_row2 = row_span.b;
  1830. const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1;
  1831. const uword in_col1 = col_all ? 0 : col_span.a;
  1832. const uword in_col2 = col_span.b;
  1833. const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1;
  1834. arma_debug_check
  1835. (
  1836. ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) )
  1837. ||
  1838. ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) )
  1839. ,
  1840. "SpMat::submat(): indices out of bounds or incorrectly used"
  1841. );
  1842. return SpSubview<eT>(*this, in_row1, in_col1, submat_n_rows, submat_n_cols);
  1843. }
  1844. template<typename eT>
  1845. inline
  1846. const SpSubview<eT>
  1847. SpMat<eT>::submat(const span& row_span, const span& col_span) const
  1848. {
  1849. arma_extra_debug_sigprint();
  1850. const bool row_all = row_span.whole;
  1851. const bool col_all = col_span.whole;
  1852. const uword local_n_rows = n_rows;
  1853. const uword local_n_cols = n_cols;
  1854. const uword in_row1 = row_all ? 0 : row_span.a;
  1855. const uword in_row2 = row_span.b;
  1856. const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1;
  1857. const uword in_col1 = col_all ? 0 : col_span.a;
  1858. const uword in_col2 = col_span.b;
  1859. const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1;
  1860. arma_debug_check
  1861. (
  1862. ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) )
  1863. ||
  1864. ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) )
  1865. ,
  1866. "SpMat::submat(): indices out of bounds or incorrectly used"
  1867. );
  1868. return SpSubview<eT>(*this, in_row1, in_col1, submat_n_rows, submat_n_cols);
  1869. }
  1870. template<typename eT>
  1871. inline
  1872. SpSubview<eT>
  1873. SpMat<eT>::operator()(const span& row_span, const span& col_span)
  1874. {
  1875. arma_extra_debug_sigprint();
  1876. return submat(row_span, col_span);
  1877. }
  1878. template<typename eT>
  1879. inline
  1880. const SpSubview<eT>
  1881. SpMat<eT>::operator()(const span& row_span, const span& col_span) const
  1882. {
  1883. arma_extra_debug_sigprint();
  1884. return submat(row_span, col_span);
  1885. }
  1886. template<typename eT>
  1887. arma_inline
  1888. SpSubview<eT>
  1889. SpMat<eT>::operator()(const uword in_row1, const uword in_col1, const SizeMat& s)
  1890. {
  1891. arma_extra_debug_sigprint();
  1892. return (*this).submat(in_row1, in_col1, s);
  1893. }
  1894. template<typename eT>
  1895. arma_inline
  1896. const SpSubview<eT>
  1897. SpMat<eT>::operator()(const uword in_row1, const uword in_col1, const SizeMat& s) const
  1898. {
  1899. arma_extra_debug_sigprint();
  1900. return (*this).submat(in_row1, in_col1, s);
  1901. }
  1902. template<typename eT>
  1903. inline
  1904. SpSubview<eT>
  1905. SpMat<eT>::head_rows(const uword N)
  1906. {
  1907. arma_extra_debug_sigprint();
  1908. arma_debug_check( (N > n_rows), "SpMat::head_rows(): size out of bounds");
  1909. return SpSubview<eT>(*this, 0, 0, N, n_cols);
  1910. }
  1911. template<typename eT>
  1912. inline
  1913. const SpSubview<eT>
  1914. SpMat<eT>::head_rows(const uword N) const
  1915. {
  1916. arma_extra_debug_sigprint();
  1917. arma_debug_check( (N > n_rows), "SpMat::head_rows(): size out of bounds");
  1918. return SpSubview<eT>(*this, 0, 0, N, n_cols);
  1919. }
  1920. template<typename eT>
  1921. inline
  1922. SpSubview<eT>
  1923. SpMat<eT>::tail_rows(const uword N)
  1924. {
  1925. arma_extra_debug_sigprint();
  1926. arma_debug_check( (N > n_rows), "SpMat::tail_rows(): size out of bounds");
  1927. const uword start_row = n_rows - N;
  1928. return SpSubview<eT>(*this, start_row, 0, N, n_cols);
  1929. }
  1930. template<typename eT>
  1931. inline
  1932. const SpSubview<eT>
  1933. SpMat<eT>::tail_rows(const uword N) const
  1934. {
  1935. arma_extra_debug_sigprint();
  1936. arma_debug_check( (N > n_rows), "SpMat::tail_rows(): size out of bounds");
  1937. const uword start_row = n_rows - N;
  1938. return SpSubview<eT>(*this, start_row, 0, N, n_cols);
  1939. }
  1940. template<typename eT>
  1941. inline
  1942. SpSubview<eT>
  1943. SpMat<eT>::head_cols(const uword N)
  1944. {
  1945. arma_extra_debug_sigprint();
  1946. arma_debug_check( (N > n_cols), "SpMat::head_cols(): size out of bounds");
  1947. return SpSubview<eT>(*this, 0, 0, n_rows, N);
  1948. }
  1949. template<typename eT>
  1950. inline
  1951. const SpSubview<eT>
  1952. SpMat<eT>::head_cols(const uword N) const
  1953. {
  1954. arma_extra_debug_sigprint();
  1955. arma_debug_check( (N > n_cols), "SpMat::head_cols(): size out of bounds");
  1956. return SpSubview<eT>(*this, 0, 0, n_rows, N);
  1957. }
  1958. template<typename eT>
  1959. inline
  1960. SpSubview<eT>
  1961. SpMat<eT>::tail_cols(const uword N)
  1962. {
  1963. arma_extra_debug_sigprint();
  1964. arma_debug_check( (N > n_cols), "SpMat::tail_cols(): size out of bounds");
  1965. const uword start_col = n_cols - N;
  1966. return SpSubview<eT>(*this, 0, start_col, n_rows, N);
  1967. }
  1968. template<typename eT>
  1969. inline
  1970. const SpSubview<eT>
  1971. SpMat<eT>::tail_cols(const uword N) const
  1972. {
  1973. arma_extra_debug_sigprint();
  1974. arma_debug_check( (N > n_cols), "SpMat::tail_cols(): size out of bounds");
  1975. const uword start_col = n_cols - N;
  1976. return SpSubview<eT>(*this, 0, start_col, n_rows, N);
  1977. }
  1978. //! creation of spdiagview (diagonal)
  1979. template<typename eT>
  1980. inline
  1981. spdiagview<eT>
  1982. SpMat<eT>::diag(const sword in_id)
  1983. {
  1984. arma_extra_debug_sigprint();
  1985. const uword row_offset = (in_id < 0) ? uword(-in_id) : 0;
  1986. const uword col_offset = (in_id > 0) ? uword( in_id) : 0;
  1987. arma_debug_check
  1988. (
  1989. ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)),
  1990. "SpMat::diag(): requested diagonal out of bounds"
  1991. );
  1992. const uword len = (std::min)(n_rows - row_offset, n_cols - col_offset);
  1993. return spdiagview<eT>(*this, row_offset, col_offset, len);
  1994. }
  1995. //! creation of spdiagview (diagonal)
  1996. template<typename eT>
  1997. inline
  1998. const spdiagview<eT>
  1999. SpMat<eT>::diag(const sword in_id) const
  2000. {
  2001. arma_extra_debug_sigprint();
  2002. const uword row_offset = uword( (in_id < 0) ? -in_id : 0 );
  2003. const uword col_offset = uword( (in_id > 0) ? in_id : 0 );
  2004. arma_debug_check
  2005. (
  2006. ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)),
  2007. "SpMat::diag(): requested diagonal out of bounds"
  2008. );
  2009. const uword len = (std::min)(n_rows - row_offset, n_cols - col_offset);
  2010. return spdiagview<eT>(*this, row_offset, col_offset, len);
  2011. }
  2012. template<typename eT>
  2013. inline
  2014. void
  2015. SpMat<eT>::swap_rows(const uword in_row1, const uword in_row2)
  2016. {
  2017. arma_extra_debug_sigprint();
  2018. arma_debug_check( ((in_row1 >= n_rows) || (in_row2 >= n_rows)), "SpMat::swap_rows(): out of bounds" );
  2019. if(in_row1 == in_row2) { return; }
  2020. sync_csc();
  2021. invalidate_cache();
  2022. // The easier way to do this, instead of collecting all the elements in one row and then swapping with the other, will be
  2023. // to iterate over each column of the matrix (since we store in column-major format) and then swap the two elements in the two rows at that time.
  2024. // We will try to avoid using the at() call since it is expensive, instead preferring to use an iterator to track our position.
  2025. uword col1 = (in_row1 < in_row2) ? in_row1 : in_row2;
  2026. uword col2 = (in_row1 < in_row2) ? in_row2 : in_row1;
  2027. for(uword lcol = 0; lcol < n_cols; lcol++)
  2028. {
  2029. // If there is nothing in this column we can ignore it.
  2030. if(col_ptrs[lcol] == col_ptrs[lcol + 1])
  2031. {
  2032. continue;
  2033. }
  2034. // These will represent the positions of the items themselves.
  2035. uword loc1 = n_nonzero + 1;
  2036. uword loc2 = n_nonzero + 1;
  2037. for(uword search_pos = col_ptrs[lcol]; search_pos < col_ptrs[lcol + 1]; search_pos++)
  2038. {
  2039. if(row_indices[search_pos] == col1)
  2040. {
  2041. loc1 = search_pos;
  2042. }
  2043. if(row_indices[search_pos] == col2)
  2044. {
  2045. loc2 = search_pos;
  2046. break; // No need to look any further.
  2047. }
  2048. }
  2049. // There are four cases: we found both elements; we found one element (loc1); we found one element (loc2); we found zero elements.
  2050. // If we found zero elements no work needs to be done and we can continue to the next column.
  2051. if((loc1 != (n_nonzero + 1)) && (loc2 != (n_nonzero + 1)))
  2052. {
  2053. // This is an easy case: just swap the values. No index modifying necessary.
  2054. eT tmp = values[loc1];
  2055. access::rw(values[loc1]) = values[loc2];
  2056. access::rw(values[loc2]) = tmp;
  2057. }
  2058. else if(loc1 != (n_nonzero + 1)) // We only found loc1 and not loc2.
  2059. {
  2060. // We need to find the correct place to move our value to. It will be forward (not backwards) because in_row2 > in_row1.
  2061. // Each iteration of the loop swaps the current value (loc1) with (loc1 + 1); in this manner we move our value down to where it should be.
  2062. while(((loc1 + 1) < col_ptrs[lcol + 1]) && (row_indices[loc1 + 1] < in_row2))
  2063. {
  2064. // Swap both the values and the indices. The column should not change.
  2065. eT tmp = values[loc1];
  2066. access::rw(values[loc1]) = values[loc1 + 1];
  2067. access::rw(values[loc1 + 1]) = tmp;
  2068. uword tmp_index = row_indices[loc1];
  2069. access::rw(row_indices[loc1]) = row_indices[loc1 + 1];
  2070. access::rw(row_indices[loc1 + 1]) = tmp_index;
  2071. loc1++; // And increment the counter.
  2072. }
  2073. // Now set the row index correctly.
  2074. access::rw(row_indices[loc1]) = in_row2;
  2075. }
  2076. else if(loc2 != (n_nonzero + 1))
  2077. {
  2078. // We need to find the correct place to move our value to. It will be backwards (not forwards) because in_row1 < in_row2.
  2079. // Each iteration of the loop swaps the current value (loc2) with (loc2 - 1); in this manner we move our value up to where it should be.
  2080. while(((loc2 - 1) >= col_ptrs[lcol]) && (row_indices[loc2 - 1] > in_row1))
  2081. {
  2082. // Swap both the values and the indices. The column should not change.
  2083. eT tmp = values[loc2];
  2084. access::rw(values[loc2]) = values[loc2 - 1];
  2085. access::rw(values[loc2 - 1]) = tmp;
  2086. uword tmp_index = row_indices[loc2];
  2087. access::rw(row_indices[loc2]) = row_indices[loc2 - 1];
  2088. access::rw(row_indices[loc2 - 1]) = tmp_index;
  2089. loc2--; // And decrement the counter.
  2090. }
  2091. // Now set the row index correctly.
  2092. access::rw(row_indices[loc2]) = in_row1;
  2093. }
  2094. /* else: no need to swap anything; both values are zero */
  2095. }
  2096. }
  2097. template<typename eT>
  2098. inline
  2099. void
  2100. SpMat<eT>::swap_cols(const uword in_col1, const uword in_col2)
  2101. {
  2102. arma_extra_debug_sigprint();
  2103. arma_debug_check( ((in_col1 >= n_cols) || (in_col2 >= n_cols)), "SpMat::swap_cols(): out of bounds" );
  2104. if(in_col1 == in_col2) { return; }
  2105. // TODO: this is a rudimentary implementation
  2106. SpMat<eT> tmp = (*this);
  2107. tmp.col(in_col1) = (*this).col(in_col2);
  2108. tmp.col(in_col2) = (*this).col(in_col1);
  2109. steal_mem(tmp);
  2110. // for(uword lrow = 0; lrow < n_rows; ++lrow)
  2111. // {
  2112. // const eT tmp = at(lrow, in_col1);
  2113. // at(lrow, in_col1) = eT( at(lrow, in_col2) );
  2114. // at(lrow, in_col2) = tmp;
  2115. // }
  2116. }
  2117. template<typename eT>
  2118. inline
  2119. void
  2120. SpMat<eT>::shed_row(const uword row_num)
  2121. {
  2122. arma_extra_debug_sigprint();
  2123. arma_debug_check (row_num >= n_rows, "SpMat::shed_row(): out of bounds");
  2124. shed_rows (row_num, row_num);
  2125. }
  2126. template<typename eT>
  2127. inline
  2128. void
  2129. SpMat<eT>::shed_col(const uword col_num)
  2130. {
  2131. arma_extra_debug_sigprint();
  2132. arma_debug_check (col_num >= n_cols, "SpMat::shed_col(): out of bounds");
  2133. shed_cols(col_num, col_num);
  2134. }
  2135. template<typename eT>
  2136. inline
  2137. void
  2138. SpMat<eT>::shed_rows(const uword in_row1, const uword in_row2)
  2139. {
  2140. arma_extra_debug_sigprint();
  2141. arma_debug_check
  2142. (
  2143. (in_row1 > in_row2) || (in_row2 >= n_rows),
  2144. "SpMat::shed_rows(): indices out of bounds or incorectly used"
  2145. );
  2146. sync_csc();
  2147. SpMat<eT> newmat(n_rows - (in_row2 - in_row1 + 1), n_cols);
  2148. // First, count the number of elements we will be removing.
  2149. uword removing = 0;
  2150. for(uword i = 0; i < n_nonzero; ++i)
  2151. {
  2152. const uword lrow = row_indices[i];
  2153. if(lrow >= in_row1 && lrow <= in_row2)
  2154. {
  2155. ++removing;
  2156. }
  2157. }
  2158. // Obtain counts of the number of points in each column and store them as the
  2159. // (invalid) column pointers of the new matrix.
  2160. for(uword i = 1; i < n_cols + 1; ++i)
  2161. {
  2162. access::rw(newmat.col_ptrs[i]) = col_ptrs[i] - col_ptrs[i - 1];
  2163. }
  2164. // Now initialize memory for the new matrix.
  2165. newmat.mem_resize(n_nonzero - removing);
  2166. // Now, copy over the elements.
  2167. // i is the index in the old matrix; j is the index in the new matrix.
  2168. const_iterator it = begin();
  2169. const_iterator it_end = end();
  2170. uword j = 0; // The index in the new matrix.
  2171. while(it != it_end)
  2172. {
  2173. const uword lrow = it.row();
  2174. const uword lcol = it.col();
  2175. if(lrow >= in_row1 && lrow <= in_row2)
  2176. {
  2177. // This element is being removed. Subtract it from the column counts.
  2178. --access::rw(newmat.col_ptrs[lcol + 1]);
  2179. }
  2180. else
  2181. {
  2182. // This element is being kept. We may need to map the row index,
  2183. // if it is past the section of rows we are removing.
  2184. if(lrow > in_row2)
  2185. {
  2186. access::rw(newmat.row_indices[j]) = lrow - (in_row2 - in_row1 + 1);
  2187. }
  2188. else
  2189. {
  2190. access::rw(newmat.row_indices[j]) = lrow;
  2191. }
  2192. access::rw(newmat.values[j]) = (*it);
  2193. ++j; // Increment index in new matrix.
  2194. }
  2195. ++it;
  2196. }
  2197. // Finally, sum the column counts so they are correct column pointers.
  2198. for(uword i = 1; i < n_cols + 1; ++i)
  2199. {
  2200. access::rw(newmat.col_ptrs[i]) += newmat.col_ptrs[i - 1];
  2201. }
  2202. // Now steal the memory of the new matrix.
  2203. steal_mem(newmat);
  2204. }
  2205. template<typename eT>
  2206. inline
  2207. void
  2208. SpMat<eT>::shed_cols(const uword in_col1, const uword in_col2)
  2209. {
  2210. arma_extra_debug_sigprint();
  2211. arma_debug_check
  2212. (
  2213. (in_col1 > in_col2) || (in_col2 >= n_cols),
  2214. "SpMat::shed_cols(): indices out of bounds or incorrectly used"
  2215. );
  2216. sync_csc();
  2217. invalidate_cache();
  2218. // First we find the locations in values and row_indices for the column entries.
  2219. uword col_beg = col_ptrs[in_col1];
  2220. uword col_end = col_ptrs[in_col2 + 1];
  2221. // Then we find the number of entries in the column.
  2222. uword diff = col_end - col_beg;
  2223. if(diff > 0)
  2224. {
  2225. eT* new_values = memory::acquire<eT> (n_nonzero - diff);
  2226. uword* new_row_indices = memory::acquire<uword>(n_nonzero - diff);
  2227. // Copy first part.
  2228. if(col_beg != 0)
  2229. {
  2230. arrayops::copy(new_values, values, col_beg);
  2231. arrayops::copy(new_row_indices, row_indices, col_beg);
  2232. }
  2233. // Copy second part.
  2234. if(col_end != n_nonzero)
  2235. {
  2236. arrayops::copy(new_values + col_beg, values + col_end, n_nonzero - col_end);
  2237. arrayops::copy(new_row_indices + col_beg, row_indices + col_end, n_nonzero - col_end);
  2238. }
  2239. if(values) { memory::release(access::rw(values)); }
  2240. if(row_indices) { memory::release(access::rw(row_indices)); }
  2241. access::rw(values) = new_values;
  2242. access::rw(row_indices) = new_row_indices;
  2243. // Update counts and such.
  2244. access::rw(n_nonzero) -= diff;
  2245. }
  2246. // Update column pointers.
  2247. const uword new_n_cols = n_cols - ((in_col2 - in_col1) + 1);
  2248. uword* new_col_ptrs = memory::acquire<uword>(new_n_cols + 2);
  2249. new_col_ptrs[new_n_cols + 1] = std::numeric_limits<uword>::max();
  2250. // Copy first set of columns (no manipulation required).
  2251. if(in_col1 != 0)
  2252. {
  2253. arrayops::copy(new_col_ptrs, col_ptrs, in_col1);
  2254. }
  2255. // Copy second set of columns (manipulation required).
  2256. uword cur_col = in_col1;
  2257. for(uword i = in_col2 + 1; i <= n_cols; ++i, ++cur_col)
  2258. {
  2259. new_col_ptrs[cur_col] = col_ptrs[i] - diff;
  2260. }
  2261. if(col_ptrs) { memory::release(access::rw(col_ptrs)); }
  2262. access::rw(col_ptrs) = new_col_ptrs;
  2263. // We update the element and column counts, and we're done.
  2264. access::rw(n_cols) = new_n_cols;
  2265. access::rw(n_elem) = n_cols * n_rows;
  2266. }
  2267. /**
  2268. * Element access; acces the i'th element (works identically to the Mat accessors).
  2269. * If there is nothing at element i, 0 is returned.
  2270. */
  2271. template<typename eT>
  2272. arma_inline
  2273. arma_warn_unused
  2274. SpMat_MapMat_val<eT>
  2275. SpMat<eT>::operator[](const uword i)
  2276. {
  2277. const uword in_col = i / n_rows;
  2278. const uword in_row = i % n_rows;
  2279. return SpMat_MapMat_val<eT>((*this), cache, in_row, in_col);
  2280. }
  2281. template<typename eT>
  2282. arma_inline
  2283. arma_warn_unused
  2284. eT
  2285. SpMat<eT>::operator[](const uword i) const
  2286. {
  2287. return get_value(i);
  2288. }
  2289. template<typename eT>
  2290. arma_inline
  2291. arma_warn_unused
  2292. SpMat_MapMat_val<eT>
  2293. SpMat<eT>::at(const uword i)
  2294. {
  2295. const uword in_col = i / n_rows;
  2296. const uword in_row = i % n_rows;
  2297. return SpMat_MapMat_val<eT>((*this), cache, in_row, in_col);
  2298. }
  2299. template<typename eT>
  2300. arma_inline
  2301. arma_warn_unused
  2302. eT
  2303. SpMat<eT>::at(const uword i) const
  2304. {
  2305. return get_value(i);
  2306. }
  2307. template<typename eT>
  2308. arma_inline
  2309. arma_warn_unused
  2310. SpMat_MapMat_val<eT>
  2311. SpMat<eT>::operator()(const uword i)
  2312. {
  2313. arma_debug_check( (i >= n_elem), "SpMat::operator(): out of bounds");
  2314. const uword in_col = i / n_rows;
  2315. const uword in_row = i % n_rows;
  2316. return SpMat_MapMat_val<eT>((*this), cache, in_row, in_col);
  2317. }
  2318. template<typename eT>
  2319. arma_inline
  2320. arma_warn_unused
  2321. eT
  2322. SpMat<eT>::operator()(const uword i) const
  2323. {
  2324. arma_debug_check( (i >= n_elem), "SpMat::operator(): out of bounds");
  2325. return get_value(i);
  2326. }
  2327. /**
  2328. * Element access; access the element at row in_rows and column in_col.
  2329. * If there is nothing at that position, 0 is returned.
  2330. */
  2331. template<typename eT>
  2332. arma_inline
  2333. arma_warn_unused
  2334. SpMat_MapMat_val<eT>
  2335. SpMat<eT>::at(const uword in_row, const uword in_col)
  2336. {
  2337. return SpMat_MapMat_val<eT>((*this), cache, in_row, in_col);
  2338. }
  2339. template<typename eT>
  2340. arma_inline
  2341. arma_warn_unused
  2342. eT
  2343. SpMat<eT>::at(const uword in_row, const uword in_col) const
  2344. {
  2345. return get_value(in_row, in_col);
  2346. }
  2347. template<typename eT>
  2348. arma_inline
  2349. arma_warn_unused
  2350. SpMat_MapMat_val<eT>
  2351. SpMat<eT>::operator()(const uword in_row, const uword in_col)
  2352. {
  2353. arma_debug_check( ((in_row >= n_rows) || (in_col >= n_cols)), "SpMat::operator(): out of bounds");
  2354. return SpMat_MapMat_val<eT>((*this), cache, in_row, in_col);
  2355. }
  2356. template<typename eT>
  2357. arma_inline
  2358. arma_warn_unused
  2359. eT
  2360. SpMat<eT>::operator()(const uword in_row, const uword in_col) const
  2361. {
  2362. arma_debug_check( ((in_row >= n_rows) || (in_col >= n_cols)), "SpMat::operator(): out of bounds");
  2363. return get_value(in_row, in_col);
  2364. }
  2365. /**
  2366. * Check if matrix is empty (no size, no values).
  2367. */
  2368. template<typename eT>
  2369. arma_inline
  2370. arma_warn_unused
  2371. bool
  2372. SpMat<eT>::is_empty() const
  2373. {
  2374. return (n_elem == 0);
  2375. }
  2376. //! returns true if the object can be interpreted as a column or row vector
  2377. template<typename eT>
  2378. arma_inline
  2379. arma_warn_unused
  2380. bool
  2381. SpMat<eT>::is_vec() const
  2382. {
  2383. return ( (n_rows == 1) || (n_cols == 1) );
  2384. }
  2385. //! returns true if the object can be interpreted as a row vector
  2386. template<typename eT>
  2387. arma_inline
  2388. arma_warn_unused
  2389. bool
  2390. SpMat<eT>::is_rowvec() const
  2391. {
  2392. return (n_rows == 1);
  2393. }
  2394. //! returns true if the object can be interpreted as a column vector
  2395. template<typename eT>
  2396. arma_inline
  2397. arma_warn_unused
  2398. bool
  2399. SpMat<eT>::is_colvec() const
  2400. {
  2401. return (n_cols == 1);
  2402. }
  2403. //! returns true if the object has the same number of non-zero rows and columnns
  2404. template<typename eT>
  2405. arma_inline
  2406. arma_warn_unused
  2407. bool
  2408. SpMat<eT>::is_square() const
  2409. {
  2410. return (n_rows == n_cols);
  2411. }
  2412. //! returns true if all of the elements are finite
  2413. template<typename eT>
  2414. inline
  2415. arma_warn_unused
  2416. bool
  2417. SpMat<eT>::is_finite() const
  2418. {
  2419. arma_extra_debug_sigprint();
  2420. sync_csc();
  2421. return arrayops::is_finite(values, n_nonzero);
  2422. }
  2423. template<typename eT>
  2424. inline
  2425. arma_warn_unused
  2426. bool
  2427. SpMat<eT>::is_symmetric() const
  2428. {
  2429. arma_extra_debug_sigprint();
  2430. const SpMat<eT>& A = (*this);
  2431. if(A.n_rows != A.n_cols) { return false; }
  2432. const SpMat<eT> tmp = A - A.st();
  2433. return (tmp.n_nonzero == uword(0));
  2434. }
  2435. template<typename eT>
  2436. inline
  2437. arma_warn_unused
  2438. bool
  2439. SpMat<eT>::is_symmetric(const typename get_pod_type<elem_type>::result tol) const
  2440. {
  2441. arma_extra_debug_sigprint();
  2442. typedef typename get_pod_type<eT>::result T;
  2443. if(tol == T(0)) { return (*this).is_symmetric(); }
  2444. arma_debug_check( (tol < T(0)), "is_symmetric(): parameter 'tol' must be >= 0" );
  2445. const SpMat<eT>& A = (*this);
  2446. if(A.n_rows != A.n_cols) { return false; }
  2447. const T norm_A = as_scalar( arma::max(sum(abs(A), 1), 0) );
  2448. if(norm_A == T(0)) { return true; }
  2449. const T norm_A_Ast = as_scalar( arma::max(sum(abs(A - A.st()), 1), 0) );
  2450. return ( (norm_A_Ast / norm_A) <= tol );
  2451. }
  2452. template<typename eT>
  2453. inline
  2454. arma_warn_unused
  2455. bool
  2456. SpMat<eT>::is_hermitian() const
  2457. {
  2458. arma_extra_debug_sigprint();
  2459. const SpMat<eT>& A = (*this);
  2460. if(A.n_rows != A.n_cols) { return false; }
  2461. const SpMat<eT> tmp = A - A.t();
  2462. return (tmp.n_nonzero == uword(0));
  2463. }
  2464. template<typename eT>
  2465. inline
  2466. arma_warn_unused
  2467. bool
  2468. SpMat<eT>::is_hermitian(const typename get_pod_type<elem_type>::result tol) const
  2469. {
  2470. arma_extra_debug_sigprint();
  2471. typedef typename get_pod_type<eT>::result T;
  2472. if(tol == T(0)) { return (*this).is_hermitian(); }
  2473. arma_debug_check( (tol < T(0)), "is_hermitian(): parameter 'tol' must be >= 0" );
  2474. const SpMat<eT>& A = (*this);
  2475. if(A.n_rows != A.n_cols) { return false; }
  2476. const T norm_A = as_scalar( arma::max(sum(abs(A), 1), 0) );
  2477. if(norm_A == T(0)) { return true; }
  2478. const T norm_A_At = as_scalar( arma::max(sum(abs(A - A.t()), 1), 0) );
  2479. return ( (norm_A_At / norm_A) <= tol );
  2480. }
  2481. template<typename eT>
  2482. inline
  2483. arma_warn_unused
  2484. bool
  2485. SpMat<eT>::has_inf() const
  2486. {
  2487. arma_extra_debug_sigprint();
  2488. sync_csc();
  2489. return arrayops::has_inf(values, n_nonzero);
  2490. }
  2491. template<typename eT>
  2492. inline
  2493. arma_warn_unused
  2494. bool
  2495. SpMat<eT>::has_nan() const
  2496. {
  2497. arma_extra_debug_sigprint();
  2498. sync_csc();
  2499. return arrayops::has_nan(values, n_nonzero);
  2500. }
  2501. //! returns true if the given index is currently in range
  2502. template<typename eT>
  2503. arma_inline
  2504. arma_warn_unused
  2505. bool
  2506. SpMat<eT>::in_range(const uword i) const
  2507. {
  2508. return (i < n_elem);
  2509. }
  2510. //! returns true if the given start and end indices are currently in range
  2511. template<typename eT>
  2512. arma_inline
  2513. arma_warn_unused
  2514. bool
  2515. SpMat<eT>::in_range(const span& x) const
  2516. {
  2517. arma_extra_debug_sigprint();
  2518. if(x.whole == true)
  2519. {
  2520. return true;
  2521. }
  2522. else
  2523. {
  2524. const uword a = x.a;
  2525. const uword b = x.b;
  2526. return ( (a <= b) && (b < n_elem) );
  2527. }
  2528. }
  2529. //! returns true if the given location is currently in range
  2530. template<typename eT>
  2531. arma_inline
  2532. arma_warn_unused
  2533. bool
  2534. SpMat<eT>::in_range(const uword in_row, const uword in_col) const
  2535. {
  2536. return ( (in_row < n_rows) && (in_col < n_cols) );
  2537. }
  2538. template<typename eT>
  2539. arma_inline
  2540. arma_warn_unused
  2541. bool
  2542. SpMat<eT>::in_range(const span& row_span, const uword in_col) const
  2543. {
  2544. arma_extra_debug_sigprint();
  2545. if(row_span.whole == true)
  2546. {
  2547. return (in_col < n_cols);
  2548. }
  2549. else
  2550. {
  2551. const uword in_row1 = row_span.a;
  2552. const uword in_row2 = row_span.b;
  2553. return ( (in_row1 <= in_row2) && (in_row2 < n_rows) && (in_col < n_cols) );
  2554. }
  2555. }
  2556. template<typename eT>
  2557. arma_inline
  2558. arma_warn_unused
  2559. bool
  2560. SpMat<eT>::in_range(const uword in_row, const span& col_span) const
  2561. {
  2562. arma_extra_debug_sigprint();
  2563. if(col_span.whole == true)
  2564. {
  2565. return (in_row < n_rows);
  2566. }
  2567. else
  2568. {
  2569. const uword in_col1 = col_span.a;
  2570. const uword in_col2 = col_span.b;
  2571. return ( (in_row < n_rows) && (in_col1 <= in_col2) && (in_col2 < n_cols) );
  2572. }
  2573. }
  2574. template<typename eT>
  2575. arma_inline
  2576. arma_warn_unused
  2577. bool
  2578. SpMat<eT>::in_range(const span& row_span, const span& col_span) const
  2579. {
  2580. arma_extra_debug_sigprint();
  2581. const uword in_row1 = row_span.a;
  2582. const uword in_row2 = row_span.b;
  2583. const uword in_col1 = col_span.a;
  2584. const uword in_col2 = col_span.b;
  2585. const bool rows_ok = row_span.whole ? true : ( (in_row1 <= in_row2) && (in_row2 < n_rows) );
  2586. const bool cols_ok = col_span.whole ? true : ( (in_col1 <= in_col2) && (in_col2 < n_cols) );
  2587. return ( (rows_ok == true) && (cols_ok == true) );
  2588. }
  2589. template<typename eT>
  2590. arma_inline
  2591. arma_warn_unused
  2592. bool
  2593. SpMat<eT>::in_range(const uword in_row, const uword in_col, const SizeMat& s) const
  2594. {
  2595. const uword l_n_rows = n_rows;
  2596. const uword l_n_cols = n_cols;
  2597. if( (in_row >= l_n_rows) || (in_col >= l_n_cols) || ((in_row + s.n_rows) > l_n_rows) || ((in_col + s.n_cols) > l_n_cols) )
  2598. {
  2599. return false;
  2600. }
  2601. else
  2602. {
  2603. return true;
  2604. }
  2605. }
  2606. template<typename eT>
  2607. arma_cold
  2608. inline
  2609. void
  2610. SpMat<eT>::impl_print(const std::string& extra_text) const
  2611. {
  2612. arma_extra_debug_sigprint();
  2613. sync_csc();
  2614. if(extra_text.length() != 0)
  2615. {
  2616. const std::streamsize orig_width = get_cout_stream().width();
  2617. get_cout_stream() << extra_text << '\n';
  2618. get_cout_stream().width(orig_width);
  2619. }
  2620. arma_ostream::print(get_cout_stream(), *this, true);
  2621. }
  2622. template<typename eT>
  2623. arma_cold
  2624. inline
  2625. void
  2626. SpMat<eT>::impl_print(std::ostream& user_stream, const std::string& extra_text) const
  2627. {
  2628. arma_extra_debug_sigprint();
  2629. sync_csc();
  2630. if(extra_text.length() != 0)
  2631. {
  2632. const std::streamsize orig_width = user_stream.width();
  2633. user_stream << extra_text << '\n';
  2634. user_stream.width(orig_width);
  2635. }
  2636. arma_ostream::print(user_stream, *this, true);
  2637. }
  2638. template<typename eT>
  2639. arma_cold
  2640. inline
  2641. void
  2642. SpMat<eT>::impl_raw_print(const std::string& extra_text) const
  2643. {
  2644. arma_extra_debug_sigprint();
  2645. sync_csc();
  2646. if(extra_text.length() != 0)
  2647. {
  2648. const std::streamsize orig_width = get_cout_stream().width();
  2649. get_cout_stream() << extra_text << '\n';
  2650. get_cout_stream().width(orig_width);
  2651. }
  2652. arma_ostream::print(get_cout_stream(), *this, false);
  2653. }
  2654. template<typename eT>
  2655. arma_cold
  2656. inline
  2657. void
  2658. SpMat<eT>::impl_raw_print(std::ostream& user_stream, const std::string& extra_text) const
  2659. {
  2660. arma_extra_debug_sigprint();
  2661. sync_csc();
  2662. if(extra_text.length() != 0)
  2663. {
  2664. const std::streamsize orig_width = user_stream.width();
  2665. user_stream << extra_text << '\n';
  2666. user_stream.width(orig_width);
  2667. }
  2668. arma_ostream::print(user_stream, *this, false);
  2669. }
  2670. /**
  2671. * Matrix printing, prepends supplied text.
  2672. * Prints 0 wherever no element exists.
  2673. */
  2674. template<typename eT>
  2675. arma_cold
  2676. inline
  2677. void
  2678. SpMat<eT>::impl_print_dense(const std::string& extra_text) const
  2679. {
  2680. arma_extra_debug_sigprint();
  2681. sync_csc();
  2682. if(extra_text.length() != 0)
  2683. {
  2684. const std::streamsize orig_width = get_cout_stream().width();
  2685. get_cout_stream() << extra_text << '\n';
  2686. get_cout_stream().width(orig_width);
  2687. }
  2688. arma_ostream::print_dense(get_cout_stream(), *this, true);
  2689. }
  2690. template<typename eT>
  2691. arma_cold
  2692. inline
  2693. void
  2694. SpMat<eT>::impl_print_dense(std::ostream& user_stream, const std::string& extra_text) const
  2695. {
  2696. arma_extra_debug_sigprint();
  2697. sync_csc();
  2698. if(extra_text.length() != 0)
  2699. {
  2700. const std::streamsize orig_width = user_stream.width();
  2701. user_stream << extra_text << '\n';
  2702. user_stream.width(orig_width);
  2703. }
  2704. arma_ostream::print_dense(user_stream, *this, true);
  2705. }
  2706. template<typename eT>
  2707. arma_cold
  2708. inline
  2709. void
  2710. SpMat<eT>::impl_raw_print_dense(const std::string& extra_text) const
  2711. {
  2712. arma_extra_debug_sigprint();
  2713. sync_csc();
  2714. if(extra_text.length() != 0)
  2715. {
  2716. const std::streamsize orig_width = get_cout_stream().width();
  2717. get_cout_stream() << extra_text << '\n';
  2718. get_cout_stream().width(orig_width);
  2719. }
  2720. arma_ostream::print_dense(get_cout_stream(), *this, false);
  2721. }
  2722. template<typename eT>
  2723. arma_cold
  2724. inline
  2725. void
  2726. SpMat<eT>::impl_raw_print_dense(std::ostream& user_stream, const std::string& extra_text) const
  2727. {
  2728. arma_extra_debug_sigprint();
  2729. sync_csc();
  2730. if(extra_text.length() != 0)
  2731. {
  2732. const std::streamsize orig_width = user_stream.width();
  2733. user_stream << extra_text << '\n';
  2734. user_stream.width(orig_width);
  2735. }
  2736. arma_ostream::print_dense(user_stream, *this, false);
  2737. }
  2738. //! Set the size to the size of another matrix.
  2739. template<typename eT>
  2740. template<typename eT2>
  2741. inline
  2742. void
  2743. SpMat<eT>::copy_size(const SpMat<eT2>& m)
  2744. {
  2745. arma_extra_debug_sigprint();
  2746. set_size(m.n_rows, m.n_cols);
  2747. }
  2748. template<typename eT>
  2749. template<typename eT2>
  2750. inline
  2751. void
  2752. SpMat<eT>::copy_size(const Mat<eT2>& m)
  2753. {
  2754. arma_extra_debug_sigprint();
  2755. set_size(m.n_rows, m.n_cols);
  2756. }
  2757. template<typename eT>
  2758. inline
  2759. void
  2760. SpMat<eT>::set_size(const uword in_elem)
  2761. {
  2762. arma_extra_debug_sigprint();
  2763. // If this is a row vector, we resize to a row vector.
  2764. if(vec_state == 2)
  2765. {
  2766. set_size(1, in_elem);
  2767. }
  2768. else
  2769. {
  2770. set_size(in_elem, 1);
  2771. }
  2772. }
  2773. template<typename eT>
  2774. inline
  2775. void
  2776. SpMat<eT>::set_size(const uword in_rows, const uword in_cols)
  2777. {
  2778. arma_extra_debug_sigprint();
  2779. invalidate_cache(); // placed here, as set_size() is used during matrix modification
  2780. if( (n_rows == in_rows) && (n_cols == in_cols) )
  2781. {
  2782. return;
  2783. }
  2784. else
  2785. {
  2786. init(in_rows, in_cols);
  2787. }
  2788. }
  2789. template<typename eT>
  2790. inline
  2791. void
  2792. SpMat<eT>::set_size(const SizeMat& s)
  2793. {
  2794. arma_extra_debug_sigprint();
  2795. (*this).set_size(s.n_rows, s.n_cols);
  2796. }
  2797. template<typename eT>
  2798. inline
  2799. void
  2800. SpMat<eT>::resize(const uword in_rows, const uword in_cols)
  2801. {
  2802. arma_extra_debug_sigprint();
  2803. if( (n_rows == in_rows) && (n_cols == in_cols) )
  2804. {
  2805. return;
  2806. }
  2807. if( (n_elem == 0) || (n_nonzero == 0) )
  2808. {
  2809. set_size(in_rows, in_cols);
  2810. return;
  2811. }
  2812. SpMat<eT> tmp(in_rows, in_cols);
  2813. if(tmp.n_elem > 0)
  2814. {
  2815. sync_csc();
  2816. const uword last_row = (std::min)(in_rows, n_rows) - 1;
  2817. const uword last_col = (std::min)(in_cols, n_cols) - 1;
  2818. tmp.submat(0, 0, last_row, last_col) = (*this).submat(0, 0, last_row, last_col);
  2819. }
  2820. steal_mem(tmp);
  2821. }
  2822. template<typename eT>
  2823. inline
  2824. void
  2825. SpMat<eT>::resize(const SizeMat& s)
  2826. {
  2827. arma_extra_debug_sigprint();
  2828. (*this).resize(s.n_rows, s.n_cols);
  2829. }
  2830. template<typename eT>
  2831. inline
  2832. void
  2833. SpMat<eT>::reshape(const uword in_rows, const uword in_cols)
  2834. {
  2835. arma_extra_debug_sigprint();
  2836. arma_check( ((in_rows*in_cols) != n_elem), "SpMat::reshape(): changing the number of elements in a sparse matrix is currently not supported" );
  2837. if( (n_rows == in_rows) && (n_cols == in_cols) ) { return; }
  2838. if(vec_state == 1) { arma_debug_check( (in_cols != 1), "SpMat::reshape(): object is a column vector; requested size is not compatible" ); }
  2839. if(vec_state == 2) { arma_debug_check( (in_rows != 1), "SpMat::reshape(): object is a row vector; requested size is not compatible" ); }
  2840. if(n_nonzero == 0)
  2841. {
  2842. (*this).zeros(in_rows, in_cols);
  2843. return;
  2844. }
  2845. if(in_cols == 1)
  2846. {
  2847. (*this).reshape_helper_intovec();
  2848. }
  2849. else
  2850. {
  2851. (*this).reshape_helper_generic(in_rows, in_cols);
  2852. }
  2853. }
  2854. template<typename eT>
  2855. inline
  2856. void
  2857. SpMat<eT>::reshape(const SizeMat& s)
  2858. {
  2859. arma_extra_debug_sigprint();
  2860. (*this).reshape(s.n_rows, s.n_cols);
  2861. }
  2862. template<typename eT>
  2863. inline
  2864. void
  2865. SpMat<eT>::reshape_helper_generic(const uword in_rows, const uword in_cols)
  2866. {
  2867. arma_extra_debug_sigprint();
  2868. sync_csc();
  2869. invalidate_cache();
  2870. // We have to modify all of the relevant row indices and the relevant column pointers.
  2871. // Iterate over all the points to do this. We won't be deleting any points, but we will be modifying
  2872. // columns and rows. We'll have to store a new set of column vectors.
  2873. uword* new_col_ptrs = memory::acquire<uword>(in_cols + 2);
  2874. new_col_ptrs[in_cols + 1] = std::numeric_limits<uword>::max();
  2875. uword* new_row_indices = memory::acquire<uword>(n_nonzero + 1);
  2876. access::rw(new_row_indices[n_nonzero]) = 0;
  2877. arrayops::fill_zeros(new_col_ptrs, in_cols + 1);
  2878. const_iterator it = begin();
  2879. const_iterator it_end = end();
  2880. for(; it != it_end; ++it)
  2881. {
  2882. uword vector_position = (it.col() * n_rows) + it.row();
  2883. new_row_indices[it.pos()] = vector_position % in_rows;
  2884. ++new_col_ptrs[vector_position / in_rows + 1];
  2885. }
  2886. // Now sum the column counts to get the new column pointers.
  2887. for(uword i = 1; i <= in_cols; i++)
  2888. {
  2889. access::rw(new_col_ptrs[i]) += new_col_ptrs[i - 1];
  2890. }
  2891. // Copy the new row indices.
  2892. if(row_indices) { memory::release(access::rw(row_indices)); }
  2893. if(col_ptrs) { memory::release(access::rw(col_ptrs)); }
  2894. access::rw(row_indices) = new_row_indices;
  2895. access::rw(col_ptrs) = new_col_ptrs;
  2896. // Now set the size.
  2897. access::rw(n_rows) = in_rows;
  2898. access::rw(n_cols) = in_cols;
  2899. }
  2900. template<typename eT>
  2901. inline
  2902. void
  2903. SpMat<eT>::reshape_helper_intovec()
  2904. {
  2905. arma_extra_debug_sigprint();
  2906. sync_csc();
  2907. invalidate_cache();
  2908. const_iterator it = begin();
  2909. const uword t_n_rows = n_rows;
  2910. const uword t_n_nonzero = n_nonzero;
  2911. for(uword i=0; i < t_n_nonzero; ++i)
  2912. {
  2913. const uword t_index = (it.col() * t_n_rows) + it.row();
  2914. // ensure the iterator is pointing to the next element
  2915. // before we overwrite the row index of the current element
  2916. ++it;
  2917. access::rw(row_indices[i]) = t_index;
  2918. }
  2919. access::rw(row_indices[n_nonzero]) = 0;
  2920. access::rw(col_ptrs[0]) = 0;
  2921. access::rw(col_ptrs[1]) = n_nonzero;
  2922. access::rw(col_ptrs[2]) = std::numeric_limits<uword>::max();
  2923. access::rw(n_rows) = (n_rows * n_cols);
  2924. access::rw(n_cols) = 1;
  2925. }
  2926. //! NOTE: don't use this form; it's deprecated and will be removed
  2927. template<typename eT>
  2928. arma_deprecated
  2929. inline
  2930. void
  2931. SpMat<eT>::reshape(const uword in_rows, const uword in_cols, const uword dim)
  2932. {
  2933. arma_extra_debug_sigprint();
  2934. arma_debug_check( (dim > 1), "SpMat::reshape(): parameter 'dim' must be 0 or 1" );
  2935. if(dim == 0)
  2936. {
  2937. (*this).reshape(in_rows, in_cols);
  2938. }
  2939. else
  2940. if(dim == 1)
  2941. {
  2942. arma_check( ((in_rows*in_cols) != n_elem), "SpMat::reshape(): changing the number of elements in a sparse matrix is currently not supported" );
  2943. sync_csc();
  2944. // Row-wise reshaping. This is more tedious and we will use a separate sparse matrix to do it.
  2945. SpMat<eT> tmp(in_rows, in_cols);
  2946. for(const_row_iterator it = begin_row(); it.pos() < n_nonzero; ++it)
  2947. {
  2948. uword vector_position = (it.row() * n_cols) + it.col();
  2949. tmp((vector_position / in_cols), (vector_position % in_cols)) = (*it);
  2950. }
  2951. steal_mem(tmp);
  2952. }
  2953. }
  2954. //! apply a functor to each non-zero element
  2955. template<typename eT>
  2956. template<typename functor>
  2957. inline
  2958. const SpMat<eT>&
  2959. SpMat<eT>::for_each(functor F)
  2960. {
  2961. arma_extra_debug_sigprint();
  2962. sync_csc();
  2963. const uword N = (*this).n_nonzero;
  2964. eT* rw_values = access::rwp(values);
  2965. bool modified = false;
  2966. bool has_zero = false;
  2967. for(uword i=0; i < N; ++i)
  2968. {
  2969. eT& new_value = rw_values[i];
  2970. const eT old_value = new_value;
  2971. F(new_value);
  2972. if(new_value != old_value) { modified = true; }
  2973. if(new_value == eT(0) ) { has_zero = true; }
  2974. }
  2975. if(modified) { invalidate_cache(); }
  2976. if(has_zero) { remove_zeros(); }
  2977. return *this;
  2978. }
  2979. template<typename eT>
  2980. template<typename functor>
  2981. inline
  2982. const SpMat<eT>&
  2983. SpMat<eT>::for_each(functor F) const
  2984. {
  2985. arma_extra_debug_sigprint();
  2986. sync_csc();
  2987. const uword N = (*this).n_nonzero;
  2988. for(uword i=0; i < N; ++i)
  2989. {
  2990. F(values[i]);
  2991. }
  2992. return *this;
  2993. }
  2994. //! transform each non-zero element using a functor
  2995. template<typename eT>
  2996. template<typename functor>
  2997. inline
  2998. const SpMat<eT>&
  2999. SpMat<eT>::transform(functor F)
  3000. {
  3001. arma_extra_debug_sigprint();
  3002. sync_csc();
  3003. invalidate_cache();
  3004. const uword N = (*this).n_nonzero;
  3005. eT* rw_values = access::rwp(values);
  3006. bool has_zero = false;
  3007. for(uword i=0; i < N; ++i)
  3008. {
  3009. eT& rw_values_i = rw_values[i];
  3010. rw_values_i = eT( F(rw_values_i) );
  3011. if(rw_values_i == eT(0)) { has_zero = true; }
  3012. }
  3013. if(has_zero) { remove_zeros(); }
  3014. return *this;
  3015. }
  3016. template<typename eT>
  3017. inline
  3018. const SpMat<eT>&
  3019. SpMat<eT>::replace(const eT old_val, const eT new_val)
  3020. {
  3021. arma_extra_debug_sigprint();
  3022. if(old_val == eT(0))
  3023. {
  3024. arma_debug_warn("SpMat::replace(): replacement not done, as old_val = 0");
  3025. }
  3026. else
  3027. {
  3028. sync_csc();
  3029. invalidate_cache();
  3030. arrayops::replace(access::rwp(values), n_nonzero, old_val, new_val);
  3031. if(new_val == eT(0)) { remove_zeros(); }
  3032. }
  3033. return *this;
  3034. }
  3035. template<typename eT>
  3036. inline
  3037. const SpMat<eT>&
  3038. SpMat<eT>::clean(const typename get_pod_type<eT>::result threshold)
  3039. {
  3040. arma_extra_debug_sigprint();
  3041. if(n_nonzero == 0) { return *this; }
  3042. sync_csc();
  3043. invalidate_cache();
  3044. arrayops::clean(access::rwp(values), n_nonzero, threshold);
  3045. remove_zeros();
  3046. return *this;
  3047. }
  3048. template<typename eT>
  3049. inline
  3050. const SpMat<eT>&
  3051. SpMat<eT>::zeros()
  3052. {
  3053. arma_extra_debug_sigprint();
  3054. const bool already_done = ( (sync_state != 1) && (n_nonzero == 0) );
  3055. if(already_done == false)
  3056. {
  3057. init(n_rows, n_cols);
  3058. }
  3059. return *this;
  3060. }
  3061. template<typename eT>
  3062. inline
  3063. const SpMat<eT>&
  3064. SpMat<eT>::zeros(const uword in_elem)
  3065. {
  3066. arma_extra_debug_sigprint();
  3067. if(vec_state == 2)
  3068. {
  3069. zeros(1, in_elem); // Row vector
  3070. }
  3071. else
  3072. {
  3073. zeros(in_elem, 1);
  3074. }
  3075. return *this;
  3076. }
  3077. template<typename eT>
  3078. inline
  3079. const SpMat<eT>&
  3080. SpMat<eT>::zeros(const uword in_rows, const uword in_cols)
  3081. {
  3082. arma_extra_debug_sigprint();
  3083. const bool already_done = ( (sync_state != 1) && (n_nonzero == 0) && (n_rows == in_rows) && (n_cols == in_cols) );
  3084. if(already_done == false)
  3085. {
  3086. init(in_rows, in_cols);
  3087. }
  3088. return *this;
  3089. }
  3090. template<typename eT>
  3091. inline
  3092. const SpMat<eT>&
  3093. SpMat<eT>::zeros(const SizeMat& s)
  3094. {
  3095. arma_extra_debug_sigprint();
  3096. return (*this).zeros(s.n_rows, s.n_cols);
  3097. }
  3098. template<typename eT>
  3099. inline
  3100. const SpMat<eT>&
  3101. SpMat<eT>::eye()
  3102. {
  3103. arma_extra_debug_sigprint();
  3104. return (*this).eye(n_rows, n_cols);
  3105. }
  3106. template<typename eT>
  3107. inline
  3108. const SpMat<eT>&
  3109. SpMat<eT>::eye(const uword in_rows, const uword in_cols)
  3110. {
  3111. arma_extra_debug_sigprint();
  3112. const uword N = (std::min)(in_rows, in_cols);
  3113. init(in_rows, in_cols, N);
  3114. arrayops::inplace_set(access::rwp(values), eT(1), N);
  3115. for(uword i = 0; i < N; ++i) { access::rw(row_indices[i]) = i; }
  3116. for(uword i = 0; i <= N; ++i) { access::rw(col_ptrs[i]) = i; }
  3117. // take into account non-square matrices
  3118. for(uword i = (N+1); i <= in_cols; ++i) { access::rw(col_ptrs[i]) = N; }
  3119. access::rw(n_nonzero) = N;
  3120. return *this;
  3121. }
  3122. template<typename eT>
  3123. inline
  3124. const SpMat<eT>&
  3125. SpMat<eT>::eye(const SizeMat& s)
  3126. {
  3127. arma_extra_debug_sigprint();
  3128. return (*this).eye(s.n_rows, s.n_cols);
  3129. }
  3130. template<typename eT>
  3131. inline
  3132. const SpMat<eT>&
  3133. SpMat<eT>::speye()
  3134. {
  3135. arma_extra_debug_sigprint();
  3136. return (*this).eye(n_rows, n_cols);
  3137. }
  3138. template<typename eT>
  3139. inline
  3140. const SpMat<eT>&
  3141. SpMat<eT>::speye(const uword in_n_rows, const uword in_n_cols)
  3142. {
  3143. arma_extra_debug_sigprint();
  3144. return (*this).eye(in_n_rows, in_n_cols);
  3145. }
  3146. template<typename eT>
  3147. inline
  3148. const SpMat<eT>&
  3149. SpMat<eT>::speye(const SizeMat& s)
  3150. {
  3151. arma_extra_debug_sigprint();
  3152. return (*this).eye(s.n_rows, s.n_cols);
  3153. }
  3154. template<typename eT>
  3155. inline
  3156. const SpMat<eT>&
  3157. SpMat<eT>::sprandu(const uword in_rows, const uword in_cols, const double density)
  3158. {
  3159. arma_extra_debug_sigprint();
  3160. arma_debug_check( ( (density < double(0)) || (density > double(1)) ), "sprandu(): density must be in the [0,1] interval" );
  3161. const uword new_n_nonzero = uword(density * double(in_rows) * double(in_cols) + 0.5);
  3162. init(in_rows, in_cols, new_n_nonzero);
  3163. if(new_n_nonzero == 0) { return *this; }
  3164. arma_rng::randu<eT>::fill( access::rwp(values), new_n_nonzero );
  3165. uvec indices = linspace<uvec>( 0u, in_rows*in_cols-1, new_n_nonzero );
  3166. // perturb the indices
  3167. for(uword i=1; i < new_n_nonzero-1; ++i)
  3168. {
  3169. const uword index_left = indices[i-1];
  3170. const uword index_right = indices[i+1];
  3171. const uword center = (index_left + index_right) / 2;
  3172. const uword delta1 = center - index_left - 1;
  3173. const uword delta2 = index_right - center - 1;
  3174. const uword min_delta = (std::min)(delta1, delta2);
  3175. uword index_new = uword( double(center) + double(min_delta) * (2.0*randu()-1.0) );
  3176. // paranoia, but better be safe than sorry
  3177. if( (index_left < index_new) && (index_new < index_right) )
  3178. {
  3179. indices[i] = index_new;
  3180. }
  3181. }
  3182. uword cur_index = 0;
  3183. uword count = 0;
  3184. for(uword lcol = 0; lcol < in_cols; ++lcol)
  3185. for(uword lrow = 0; lrow < in_rows; ++lrow)
  3186. {
  3187. if(count == indices[cur_index])
  3188. {
  3189. access::rw(row_indices[cur_index]) = lrow;
  3190. access::rw(col_ptrs[lcol + 1])++;
  3191. ++cur_index;
  3192. }
  3193. ++count;
  3194. }
  3195. if(cur_index != new_n_nonzero)
  3196. {
  3197. // Fix size to correct size.
  3198. mem_resize(cur_index);
  3199. }
  3200. // Sum column pointers.
  3201. for(uword lcol = 1; lcol <= in_cols; ++lcol)
  3202. {
  3203. access::rw(col_ptrs[lcol]) += col_ptrs[lcol - 1];
  3204. }
  3205. return *this;
  3206. }
  3207. template<typename eT>
  3208. inline
  3209. const SpMat<eT>&
  3210. SpMat<eT>::sprandu(const SizeMat& s, const double density)
  3211. {
  3212. arma_extra_debug_sigprint();
  3213. return (*this).sprandu(s.n_rows, s.n_cols, density);
  3214. }
  3215. template<typename eT>
  3216. inline
  3217. const SpMat<eT>&
  3218. SpMat<eT>::sprandn(const uword in_rows, const uword in_cols, const double density)
  3219. {
  3220. arma_extra_debug_sigprint();
  3221. arma_debug_check( ( (density < double(0)) || (density > double(1)) ), "sprandn(): density must be in the [0,1] interval" );
  3222. const uword new_n_nonzero = uword(density * double(in_rows) * double(in_cols) + 0.5);
  3223. init(in_rows, in_cols, new_n_nonzero);
  3224. if(new_n_nonzero == 0) { return *this; }
  3225. arma_rng::randn<eT>::fill( access::rwp(values), new_n_nonzero );
  3226. uvec indices = linspace<uvec>( 0u, in_rows*in_cols-1, new_n_nonzero );
  3227. // perturb the indices
  3228. for(uword i=1; i < new_n_nonzero-1; ++i)
  3229. {
  3230. const uword index_left = indices[i-1];
  3231. const uword index_right = indices[i+1];
  3232. const uword center = (index_left + index_right) / 2;
  3233. const uword delta1 = center - index_left - 1;
  3234. const uword delta2 = index_right - center - 1;
  3235. const uword min_delta = (std::min)(delta1, delta2);
  3236. uword index_new = uword( double(center) + double(min_delta) * (2.0*randu()-1.0) );
  3237. // paranoia, but better be safe than sorry
  3238. if( (index_left < index_new) && (index_new < index_right) )
  3239. {
  3240. indices[i] = index_new;
  3241. }
  3242. }
  3243. uword cur_index = 0;
  3244. uword count = 0;
  3245. for(uword lcol = 0; lcol < in_cols; ++lcol)
  3246. for(uword lrow = 0; lrow < in_rows; ++lrow)
  3247. {
  3248. if(count == indices[cur_index])
  3249. {
  3250. access::rw(row_indices[cur_index]) = lrow;
  3251. access::rw(col_ptrs[lcol + 1])++;
  3252. ++cur_index;
  3253. }
  3254. ++count;
  3255. }
  3256. if(cur_index != new_n_nonzero)
  3257. {
  3258. // Fix size to correct size.
  3259. mem_resize(cur_index);
  3260. }
  3261. // Sum column pointers.
  3262. for(uword lcol = 1; lcol <= in_cols; ++lcol)
  3263. {
  3264. access::rw(col_ptrs[lcol]) += col_ptrs[lcol - 1];
  3265. }
  3266. return *this;
  3267. }
  3268. template<typename eT>
  3269. inline
  3270. const SpMat<eT>&
  3271. SpMat<eT>::sprandn(const SizeMat& s, const double density)
  3272. {
  3273. arma_extra_debug_sigprint();
  3274. return (*this).sprandn(s.n_rows, s.n_cols, density);
  3275. }
  3276. template<typename eT>
  3277. inline
  3278. void
  3279. SpMat<eT>::reset()
  3280. {
  3281. arma_extra_debug_sigprint();
  3282. switch(vec_state)
  3283. {
  3284. default:
  3285. init(0, 0);
  3286. break;
  3287. case 1:
  3288. init(0, 1);
  3289. break;
  3290. case 2:
  3291. init(1, 0);
  3292. break;
  3293. }
  3294. }
  3295. template<typename eT>
  3296. inline
  3297. void
  3298. SpMat<eT>::reserve(const uword in_rows, const uword in_cols, const uword new_n_nonzero)
  3299. {
  3300. arma_extra_debug_sigprint();
  3301. init(in_rows, in_cols, new_n_nonzero);
  3302. }
  3303. template<typename eT>
  3304. template<typename T1>
  3305. inline
  3306. void
  3307. SpMat<eT>::set_real(const SpBase<typename SpMat<eT>::pod_type,T1>& X)
  3308. {
  3309. arma_extra_debug_sigprint();
  3310. SpMat_aux::set_real(*this, X);
  3311. }
  3312. template<typename eT>
  3313. template<typename T1>
  3314. inline
  3315. void
  3316. SpMat<eT>::set_imag(const SpBase<typename SpMat<eT>::pod_type,T1>& X)
  3317. {
  3318. arma_extra_debug_sigprint();
  3319. SpMat_aux::set_imag(*this, X);
  3320. }
  3321. //! save the matrix to a file
  3322. template<typename eT>
  3323. inline
  3324. arma_cold
  3325. bool
  3326. SpMat<eT>::save(const std::string name, const file_type type, const bool print_status) const
  3327. {
  3328. arma_extra_debug_sigprint();
  3329. sync_csc();
  3330. bool save_okay;
  3331. switch(type)
  3332. {
  3333. case csv_ascii:
  3334. return (*this).save(csv_name(name), type, print_status);
  3335. break;
  3336. case arma_binary:
  3337. save_okay = diskio::save_arma_binary(*this, name);
  3338. break;
  3339. case coord_ascii:
  3340. save_okay = diskio::save_coord_ascii(*this, name);
  3341. break;
  3342. default:
  3343. if(print_status) { arma_debug_warn("SpMat::save(): unsupported file type"); }
  3344. save_okay = false;
  3345. }
  3346. if(print_status && (save_okay == false)) { arma_debug_warn("SpMat::save(): couldn't write to ", name); }
  3347. return save_okay;
  3348. }
  3349. template<typename eT>
  3350. inline
  3351. arma_cold
  3352. bool
  3353. SpMat<eT>::save(const csv_name& spec, const file_type type, const bool print_status) const
  3354. {
  3355. arma_extra_debug_sigprint();
  3356. if(type != csv_ascii)
  3357. {
  3358. arma_debug_check(true, "SpMat::save(): unsupported file type for csv_name()");
  3359. return false;
  3360. }
  3361. const bool do_trans = bool(spec.opts.flags & csv_opts::flag_trans );
  3362. const bool no_header = bool(spec.opts.flags & csv_opts::flag_no_header );
  3363. bool with_header = bool(spec.opts.flags & csv_opts::flag_with_header);
  3364. arma_extra_debug_print("SpMat::save(csv_name): enabled flags:");
  3365. if(do_trans ) { arma_extra_debug_print("trans"); }
  3366. if(no_header ) { arma_extra_debug_print("no_header"); }
  3367. if(with_header) { arma_extra_debug_print("with_header"); }
  3368. if(no_header) { with_header = false; }
  3369. if(with_header)
  3370. {
  3371. if( (spec.header_ro.n_cols != 1) && (spec.header_ro.n_rows != 1) )
  3372. {
  3373. if(print_status) { arma_debug_warn("SpMat::save(): given header must have a vector layout"); }
  3374. return false;
  3375. }
  3376. for(uword i=0; i < spec.header_ro.n_elem; ++i)
  3377. {
  3378. const std::string& token = spec.header_ro.at(i);
  3379. if(token.find(',') != std::string::npos)
  3380. {
  3381. if(print_status) { arma_debug_warn("SpMat::save(): token within the header contains a comma: '", token, "'"); }
  3382. return false;
  3383. }
  3384. }
  3385. const uword save_n_cols = (do_trans) ? (*this).n_rows : (*this).n_cols;
  3386. if(spec.header_ro.n_elem != save_n_cols)
  3387. {
  3388. if(print_status) { arma_debug_warn("SpMat::save(): size mistmach between header and matrix"); }
  3389. return false;
  3390. }
  3391. }
  3392. bool save_okay = false;
  3393. if(do_trans)
  3394. {
  3395. const SpMat<eT> tmp = (*this).st();
  3396. save_okay = diskio::save_csv_ascii(tmp, spec.filename, spec.header_ro, with_header);
  3397. }
  3398. else
  3399. {
  3400. save_okay = diskio::save_csv_ascii(*this, spec.filename, spec.header_ro, with_header);
  3401. }
  3402. if((print_status == true) && (save_okay == false))
  3403. {
  3404. arma_debug_warn("SpMat::save(): couldn't write to ", spec.filename);
  3405. }
  3406. return save_okay;
  3407. }
  3408. //! save the matrix to a stream
  3409. template<typename eT>
  3410. inline
  3411. arma_cold
  3412. bool
  3413. SpMat<eT>::save(std::ostream& os, const file_type type, const bool print_status) const
  3414. {
  3415. arma_extra_debug_sigprint();
  3416. sync_csc();
  3417. bool save_okay;
  3418. switch(type)
  3419. {
  3420. case csv_ascii:
  3421. save_okay = diskio::save_csv_ascii(*this, os);
  3422. break;
  3423. case arma_binary:
  3424. save_okay = diskio::save_arma_binary(*this, os);
  3425. break;
  3426. case coord_ascii:
  3427. save_okay = diskio::save_coord_ascii(*this, os);
  3428. break;
  3429. default:
  3430. if(print_status) { arma_debug_warn("SpMat::save(): unsupported file type"); }
  3431. save_okay = false;
  3432. }
  3433. if(print_status && (save_okay == false)) { arma_debug_warn("SpMat::save(): couldn't write to the given stream"); }
  3434. return save_okay;
  3435. }
  3436. //! load a matrix from a file
  3437. template<typename eT>
  3438. inline
  3439. arma_cold
  3440. bool
  3441. SpMat<eT>::load(const std::string name, const file_type type, const bool print_status)
  3442. {
  3443. arma_extra_debug_sigprint();
  3444. invalidate_cache();
  3445. bool load_okay;
  3446. std::string err_msg;
  3447. switch(type)
  3448. {
  3449. // case auto_detect:
  3450. // load_okay = diskio::load_auto_detect(*this, name, err_msg);
  3451. // break;
  3452. case csv_ascii:
  3453. return (*this).load(csv_name(name), type, print_status);
  3454. break;
  3455. case arma_binary:
  3456. load_okay = diskio::load_arma_binary(*this, name, err_msg);
  3457. break;
  3458. case coord_ascii:
  3459. load_okay = diskio::load_coord_ascii(*this, name, err_msg);
  3460. break;
  3461. default:
  3462. if(print_status) { arma_debug_warn("SpMat::load(): unsupported file type"); }
  3463. load_okay = false;
  3464. }
  3465. if(print_status && (load_okay == false))
  3466. {
  3467. if(err_msg.length() > 0)
  3468. {
  3469. arma_debug_warn("SpMat::load(): ", err_msg, name);
  3470. }
  3471. else
  3472. {
  3473. arma_debug_warn("SpMat::load(): couldn't read ", name);
  3474. }
  3475. }
  3476. if(load_okay == false)
  3477. {
  3478. (*this).reset();
  3479. }
  3480. return load_okay;
  3481. }
  3482. template<typename eT>
  3483. inline
  3484. arma_cold
  3485. bool
  3486. SpMat<eT>::load(const csv_name& spec, const file_type type, const bool print_status)
  3487. {
  3488. arma_extra_debug_sigprint();
  3489. if(type != csv_ascii)
  3490. {
  3491. arma_debug_check(true, "SpMat::load(): unsupported file type for csv_name()");
  3492. return false;
  3493. }
  3494. const bool do_trans = bool(spec.opts.flags & csv_opts::flag_trans );
  3495. const bool no_header = bool(spec.opts.flags & csv_opts::flag_no_header );
  3496. bool with_header = bool(spec.opts.flags & csv_opts::flag_with_header);
  3497. arma_extra_debug_print("SpMat::load(csv_name): enabled flags:");
  3498. if(do_trans ) { arma_extra_debug_print("trans"); }
  3499. if(no_header ) { arma_extra_debug_print("no_header"); }
  3500. if(with_header) { arma_extra_debug_print("with_header"); }
  3501. if(no_header) { with_header = false; }
  3502. bool load_okay = false;
  3503. std::string err_msg;
  3504. if(do_trans)
  3505. {
  3506. SpMat<eT> tmp_mat;
  3507. load_okay = diskio::load_csv_ascii(tmp_mat, spec.filename, err_msg, spec.header_rw, with_header);
  3508. if(load_okay)
  3509. {
  3510. (*this) = tmp_mat.st();
  3511. if(with_header)
  3512. {
  3513. // field::set_size() preserves data if the number of elements hasn't changed
  3514. spec.header_rw.set_size(spec.header_rw.n_elem, 1);
  3515. }
  3516. }
  3517. }
  3518. else
  3519. {
  3520. load_okay = diskio::load_csv_ascii(*this, spec.filename, err_msg, spec.header_rw, with_header);
  3521. }
  3522. if(print_status == true)
  3523. {
  3524. if(load_okay == false)
  3525. {
  3526. if(err_msg.length() > 0)
  3527. {
  3528. arma_debug_warn("SpMat::load(): ", err_msg, spec.filename);
  3529. }
  3530. else
  3531. {
  3532. arma_debug_warn("SpMat::load(): couldn't read ", spec.filename);
  3533. }
  3534. }
  3535. else
  3536. {
  3537. const uword load_n_cols = (do_trans) ? (*this).n_rows : (*this).n_cols;
  3538. if(with_header && (spec.header_rw.n_elem != load_n_cols))
  3539. {
  3540. arma_debug_warn("SpMat::load(): size mistmach between header and matrix");
  3541. }
  3542. }
  3543. }
  3544. if(load_okay == false)
  3545. {
  3546. (*this).reset();
  3547. if(with_header) { spec.header_rw.reset(); }
  3548. }
  3549. return load_okay;
  3550. }
  3551. //! load a matrix from a stream
  3552. template<typename eT>
  3553. inline
  3554. arma_cold
  3555. bool
  3556. SpMat<eT>::load(std::istream& is, const file_type type, const bool print_status)
  3557. {
  3558. arma_extra_debug_sigprint();
  3559. invalidate_cache();
  3560. bool load_okay;
  3561. std::string err_msg;
  3562. switch(type)
  3563. {
  3564. // case auto_detect:
  3565. // load_okay = diskio::load_auto_detect(*this, is, err_msg);
  3566. // break;
  3567. case csv_ascii:
  3568. load_okay = diskio::load_csv_ascii(*this, is, err_msg);
  3569. break;
  3570. case arma_binary:
  3571. load_okay = diskio::load_arma_binary(*this, is, err_msg);
  3572. break;
  3573. case coord_ascii:
  3574. load_okay = diskio::load_coord_ascii(*this, is, err_msg);
  3575. break;
  3576. default:
  3577. if(print_status) { arma_debug_warn("SpMat::load(): unsupported file type"); }
  3578. load_okay = false;
  3579. }
  3580. if(print_status && (load_okay == false))
  3581. {
  3582. if(err_msg.length() > 0)
  3583. {
  3584. arma_debug_warn("SpMat::load(): ", err_msg, "the given stream");
  3585. }
  3586. else
  3587. {
  3588. arma_debug_warn("SpMat::load(): couldn't load from the given stream");
  3589. }
  3590. }
  3591. if(load_okay == false)
  3592. {
  3593. (*this).reset();
  3594. }
  3595. return load_okay;
  3596. }
  3597. //! save the matrix to a file, without printing any error messages
  3598. template<typename eT>
  3599. inline
  3600. arma_cold
  3601. bool
  3602. SpMat<eT>::quiet_save(const std::string name, const file_type type) const
  3603. {
  3604. arma_extra_debug_sigprint();
  3605. return (*this).save(name, type, false);
  3606. }
  3607. //! save the matrix to a stream, without printing any error messages
  3608. template<typename eT>
  3609. inline
  3610. arma_cold
  3611. bool
  3612. SpMat<eT>::quiet_save(std::ostream& os, const file_type type) const
  3613. {
  3614. arma_extra_debug_sigprint();
  3615. return (*this).save(os, type, false);
  3616. }
  3617. //! load a matrix from a file, without printing any error messages
  3618. template<typename eT>
  3619. inline
  3620. arma_cold
  3621. bool
  3622. SpMat<eT>::quiet_load(const std::string name, const file_type type)
  3623. {
  3624. arma_extra_debug_sigprint();
  3625. return (*this).load(name, type, false);
  3626. }
  3627. //! load a matrix from a stream, without printing any error messages
  3628. template<typename eT>
  3629. inline
  3630. arma_cold
  3631. bool
  3632. SpMat<eT>::quiet_load(std::istream& is, const file_type type)
  3633. {
  3634. arma_extra_debug_sigprint();
  3635. return (*this).load(is, type, false);
  3636. }
  3637. /**
  3638. * Initialize the matrix to the specified size. Data is not preserved, so the matrix is assumed to be entirely sparse (empty).
  3639. */
  3640. template<typename eT>
  3641. inline
  3642. void
  3643. SpMat<eT>::init(uword in_rows, uword in_cols, const uword new_n_nonzero)
  3644. {
  3645. arma_extra_debug_sigprint();
  3646. invalidate_cache(); // placed here, as init() is used during matrix modification
  3647. // Clean out the existing memory.
  3648. if(values ) { memory::release(access::rw(values)); }
  3649. if(row_indices) { memory::release(access::rw(row_indices)); }
  3650. if(col_ptrs ) { memory::release(access::rw(col_ptrs)); }
  3651. init_cold(in_rows, in_cols, new_n_nonzero);
  3652. }
  3653. template<typename eT>
  3654. inline
  3655. void
  3656. arma_cold
  3657. SpMat<eT>::init_cold(uword in_rows, uword in_cols, const uword new_n_nonzero)
  3658. {
  3659. arma_extra_debug_sigprint();
  3660. // Verify that we are allowed to do this.
  3661. if(vec_state > 0)
  3662. {
  3663. if((in_rows == 0) && (in_cols == 0))
  3664. {
  3665. if(vec_state == 1) { in_cols = 1; }
  3666. if(vec_state == 2) { in_rows = 1; }
  3667. }
  3668. else
  3669. {
  3670. if(vec_state == 1) { arma_debug_check( (in_cols != 1), "SpMat::init(): object is a column vector; requested size is not compatible" ); }
  3671. if(vec_state == 2) { arma_debug_check( (in_rows != 1), "SpMat::init(): object is a row vector; requested size is not compatible" ); }
  3672. }
  3673. }
  3674. #if defined(ARMA_64BIT_WORD)
  3675. const char* error_message = "SpMat::init(): requested size is too large";
  3676. #else
  3677. const char* error_message = "SpMat::init(): requested size is too large; suggest to compile in C++11 mode and/or enable ARMA_64BIT_WORD";
  3678. #endif
  3679. // Ensure that n_elem can hold the result of (n_rows * n_cols)
  3680. arma_debug_check
  3681. (
  3682. (
  3683. ( (in_rows > ARMA_MAX_UHWORD) || (in_cols > ARMA_MAX_UHWORD) )
  3684. ? ( (double(in_rows) * double(in_cols)) > double(ARMA_MAX_UWORD) )
  3685. : false
  3686. ),
  3687. error_message
  3688. );
  3689. access::rw(col_ptrs) = memory::acquire<uword>(in_cols + 2);
  3690. access::rw(values) = memory::acquire<eT> (new_n_nonzero + 1);
  3691. access::rw(row_indices) = memory::acquire<uword>(new_n_nonzero + 1);
  3692. // fill column pointers with 0,
  3693. // except for the last element which contains the maximum possible element
  3694. // (so iterators terminate correctly).
  3695. arrayops::fill_zeros(access::rwp(col_ptrs), in_cols + 1);
  3696. access::rw(col_ptrs[in_cols + 1]) = std::numeric_limits<uword>::max();
  3697. access::rw( values[new_n_nonzero]) = 0;
  3698. access::rw(row_indices[new_n_nonzero]) = 0;
  3699. // Set the new size accordingly.
  3700. access::rw(n_rows) = in_rows;
  3701. access::rw(n_cols) = in_cols;
  3702. access::rw(n_elem) = (in_rows * in_cols);
  3703. access::rw(n_nonzero) = new_n_nonzero;
  3704. }
  3705. template<typename eT>
  3706. inline
  3707. void
  3708. SpMat<eT>::init(const std::string& text)
  3709. {
  3710. arma_extra_debug_sigprint();
  3711. Mat<eT> tmp(text);
  3712. if(vec_state == 1)
  3713. {
  3714. if((tmp.n_elem > 0) && tmp.is_vec())
  3715. {
  3716. access::rw(tmp.n_rows) = tmp.n_elem;
  3717. access::rw(tmp.n_cols) = 1;
  3718. }
  3719. }
  3720. if(vec_state == 2)
  3721. {
  3722. if((tmp.n_elem > 0) && tmp.is_vec())
  3723. {
  3724. access::rw(tmp.n_rows) = 1;
  3725. access::rw(tmp.n_cols) = tmp.n_elem;
  3726. }
  3727. }
  3728. (*this).operator=(tmp);
  3729. }
  3730. template<typename eT>
  3731. inline
  3732. void
  3733. SpMat<eT>::init(const SpMat<eT>& x)
  3734. {
  3735. arma_extra_debug_sigprint();
  3736. if(this == &x) { return; }
  3737. bool init_done = false;
  3738. #if defined(ARMA_USE_OPENMP)
  3739. if(x.sync_state == 1)
  3740. {
  3741. #pragma omp critical (arma_SpMat_init)
  3742. if(x.sync_state == 1)
  3743. {
  3744. (*this).init(x.cache);
  3745. init_done = true;
  3746. }
  3747. }
  3748. #elif (defined(ARMA_USE_CXX11) && !defined(ARMA_DONT_USE_CXX11_MUTEX))
  3749. if(x.sync_state == 1)
  3750. {
  3751. x.cache_mutex.lock();
  3752. if(x.sync_state == 1)
  3753. {
  3754. (*this).init(x.cache);
  3755. init_done = true;
  3756. }
  3757. x.cache_mutex.unlock();
  3758. }
  3759. #else
  3760. if(x.sync_state == 1)
  3761. {
  3762. (*this).init(x.cache);
  3763. init_done = true;
  3764. }
  3765. #endif
  3766. if(init_done == false)
  3767. {
  3768. (*this).init_simple(x);
  3769. }
  3770. }
  3771. template<typename eT>
  3772. inline
  3773. void
  3774. SpMat<eT>::init(const MapMat<eT>& x)
  3775. {
  3776. arma_extra_debug_sigprint();
  3777. const uword x_n_rows = x.n_rows;
  3778. const uword x_n_cols = x.n_cols;
  3779. const uword x_n_nz = x.get_n_nonzero();
  3780. init(x_n_rows, x_n_cols, x_n_nz);
  3781. if(x_n_nz == 0) { return; }
  3782. typename MapMat<eT>::map_type& x_map_ref = *(x.map_ptr);
  3783. typename MapMat<eT>::map_type::const_iterator x_it = x_map_ref.begin();
  3784. uword x_col = 0;
  3785. uword x_col_index_start = 0;
  3786. uword x_col_index_endp1 = x_n_rows;
  3787. for(uword i=0; i < x_n_nz; ++i)
  3788. {
  3789. const std::pair<uword, eT>& x_entry = (*x_it);
  3790. const uword x_index = x_entry.first;
  3791. const eT x_val = x_entry.second;
  3792. // have we gone past the curent column?
  3793. if(x_index >= x_col_index_endp1)
  3794. {
  3795. x_col = x_index / x_n_rows;
  3796. x_col_index_start = x_col * x_n_rows;
  3797. x_col_index_endp1 = x_col_index_start + x_n_rows;
  3798. }
  3799. const uword x_row = x_index - x_col_index_start;
  3800. // // sanity check
  3801. //
  3802. // const uword tmp_x_row = x_index % x_n_rows;
  3803. // const uword tmp_x_col = x_index / x_n_rows;
  3804. //
  3805. // if(x_row != tmp_x_row) { cout << "x_row != tmp_x_row" << endl; exit(-1); }
  3806. // if(x_col != tmp_x_col) { cout << "x_col != tmp_x_col" << endl; exit(-1); }
  3807. access::rw(values[i]) = x_val;
  3808. access::rw(row_indices[i]) = x_row;
  3809. access::rw(col_ptrs[ x_col + 1 ])++;
  3810. ++x_it;
  3811. }
  3812. for(uword i = 0; i < x_n_cols; ++i)
  3813. {
  3814. access::rw(col_ptrs[i + 1]) += col_ptrs[i];
  3815. }
  3816. // // OLD METHOD
  3817. //
  3818. // for(uword i=0; i < x_n_nz; ++i)
  3819. // {
  3820. // const std::pair<uword, eT>& x_entry = (*x_it);
  3821. //
  3822. // const uword x_index = x_entry.first;
  3823. // const eT x_val = x_entry.second;
  3824. //
  3825. // const uword x_row = x_index % x_n_rows;
  3826. // const uword x_col = x_index / x_n_rows;
  3827. //
  3828. // access::rw(values[i]) = x_val;
  3829. // access::rw(row_indices[i]) = x_row;
  3830. //
  3831. // access::rw(col_ptrs[ x_col + 1 ])++;
  3832. //
  3833. // ++x_it;
  3834. // }
  3835. //
  3836. //
  3837. // for(uword i = 0; i < x_n_cols; ++i)
  3838. // {
  3839. // access::rw(col_ptrs[i + 1]) += col_ptrs[i];
  3840. // }
  3841. }
  3842. template<typename eT>
  3843. inline
  3844. void
  3845. SpMat<eT>::init_simple(const SpMat<eT>& x)
  3846. {
  3847. arma_extra_debug_sigprint();
  3848. if(this == &x) { return; }
  3849. init(x.n_rows, x.n_cols, x.n_nonzero);
  3850. if(x.values ) { arrayops::copy(access::rwp(values), x.values, x.n_nonzero + 1); }
  3851. if(x.row_indices) { arrayops::copy(access::rwp(row_indices), x.row_indices, x.n_nonzero + 1); }
  3852. if(x.col_ptrs ) { arrayops::copy(access::rwp(col_ptrs), x.col_ptrs, x.n_cols + 1); }
  3853. }
  3854. template<typename eT>
  3855. inline
  3856. void
  3857. SpMat<eT>::init_batch_std(const Mat<uword>& locs, const Mat<eT>& vals, const bool sort_locations)
  3858. {
  3859. arma_extra_debug_sigprint();
  3860. // Resize to correct number of elements.
  3861. mem_resize(vals.n_elem);
  3862. // Reset column pointers to zero.
  3863. arrayops::fill_zeros(access::rwp(col_ptrs), n_cols + 1);
  3864. bool actually_sorted = true;
  3865. if(sort_locations == true)
  3866. {
  3867. // check if we really need a time consuming sort
  3868. const uword locs_n_cols = locs.n_cols;
  3869. for(uword i = 1; i < locs_n_cols; ++i)
  3870. {
  3871. const uword* locs_i = locs.colptr(i );
  3872. const uword* locs_im1 = locs.colptr(i-1);
  3873. const uword row_i = locs_i[0];
  3874. const uword col_i = locs_i[1];
  3875. const uword row_im1 = locs_im1[0];
  3876. const uword col_im1 = locs_im1[1];
  3877. if( (col_i < col_im1) || ((col_i == col_im1) && (row_i <= row_im1)) )
  3878. {
  3879. actually_sorted = false;
  3880. break;
  3881. }
  3882. }
  3883. if(actually_sorted == false)
  3884. {
  3885. // see op_sort_index_bones.hpp for the definition of arma_sort_index_packet and arma_sort_index_helper_ascend
  3886. std::vector< arma_sort_index_packet<uword> > packet_vec(locs_n_cols);
  3887. const uword* locs_mem = locs.memptr();
  3888. for(uword i = 0; i < locs_n_cols; ++i)
  3889. {
  3890. const uword row = (*locs_mem); locs_mem++;
  3891. const uword col = (*locs_mem); locs_mem++;
  3892. packet_vec[i].val = (col * n_rows) + row;
  3893. packet_vec[i].index = i;
  3894. }
  3895. arma_sort_index_helper_ascend<uword> comparator;
  3896. std::sort( packet_vec.begin(), packet_vec.end(), comparator );
  3897. // insert the elements in the sorted order
  3898. for(uword i = 0; i < locs_n_cols; ++i)
  3899. {
  3900. const uword index = packet_vec[i].index;
  3901. const uword* locs_i = locs.colptr(index);
  3902. const uword row_i = locs_i[0];
  3903. const uword col_i = locs_i[1];
  3904. arma_debug_check( ( (row_i >= n_rows) || (col_i >= n_cols) ), "SpMat::SpMat(): invalid row or column index" );
  3905. if(i > 0)
  3906. {
  3907. const uword prev_index = packet_vec[i-1].index;
  3908. const uword* locs_im1 = locs.colptr(prev_index);
  3909. const uword row_im1 = locs_im1[0];
  3910. const uword col_im1 = locs_im1[1];
  3911. arma_debug_check( ( (row_i == row_im1) && (col_i == col_im1) ), "SpMat::SpMat(): detected identical locations" );
  3912. }
  3913. access::rw(values[i]) = vals[index];
  3914. access::rw(row_indices[i]) = row_i;
  3915. access::rw(col_ptrs[ col_i + 1 ])++;
  3916. }
  3917. }
  3918. }
  3919. if( (sort_locations == false) || (actually_sorted == true) )
  3920. {
  3921. // Now set the values and row indices correctly.
  3922. // Increment the column pointers in each column (so they are column "counts").
  3923. const uword locs_n_cols = locs.n_cols;
  3924. for(uword i=0; i < locs_n_cols; ++i)
  3925. {
  3926. const uword* locs_i = locs.colptr(i);
  3927. const uword row_i = locs_i[0];
  3928. const uword col_i = locs_i[1];
  3929. arma_debug_check( ( (row_i >= n_rows) || (col_i >= n_cols) ), "SpMat::SpMat(): invalid row or column index" );
  3930. if(i > 0)
  3931. {
  3932. const uword* locs_im1 = locs.colptr(i-1);
  3933. const uword row_im1 = locs_im1[0];
  3934. const uword col_im1 = locs_im1[1];
  3935. arma_debug_check
  3936. (
  3937. ( (col_i < col_im1) || ((col_i == col_im1) && (row_i < row_im1)) ),
  3938. "SpMat::SpMat(): out of order points; either pass sort_locations = true, or sort points in column-major ordering"
  3939. );
  3940. arma_debug_check( ( (col_i == col_im1) && (row_i == row_im1) ), "SpMat::SpMat(): detected identical locations" );
  3941. }
  3942. access::rw(values[i]) = vals[i];
  3943. access::rw(row_indices[i]) = row_i;
  3944. access::rw(col_ptrs[ col_i + 1 ])++;
  3945. }
  3946. }
  3947. // Now fix the column pointers.
  3948. for(uword i = 0; i < n_cols; ++i)
  3949. {
  3950. access::rw(col_ptrs[i + 1]) += col_ptrs[i];
  3951. }
  3952. }
  3953. template<typename eT>
  3954. inline
  3955. void
  3956. SpMat<eT>::init_batch_add(const Mat<uword>& locs, const Mat<eT>& vals, const bool sort_locations)
  3957. {
  3958. arma_extra_debug_sigprint();
  3959. if(locs.n_cols < 2)
  3960. {
  3961. init_batch_std(locs, vals, false);
  3962. return;
  3963. }
  3964. // Reset column pointers to zero.
  3965. arrayops::fill_zeros(access::rwp(col_ptrs), n_cols + 1);
  3966. bool actually_sorted = true;
  3967. if(sort_locations == true)
  3968. {
  3969. // sort_index() uses std::sort() which may use quicksort... so we better
  3970. // make sure it's not already sorted before taking an O(N^2) sort penalty.
  3971. for(uword i = 1; i < locs.n_cols; ++i)
  3972. {
  3973. const uword* locs_i = locs.colptr(i );
  3974. const uword* locs_im1 = locs.colptr(i-1);
  3975. if( (locs_i[1] < locs_im1[1]) || (locs_i[1] == locs_im1[1] && locs_i[0] <= locs_im1[0]) )
  3976. {
  3977. actually_sorted = false;
  3978. break;
  3979. }
  3980. }
  3981. if(actually_sorted == false)
  3982. {
  3983. // This may not be the fastest possible implementation but it maximizes code reuse.
  3984. Col<uword> abslocs(locs.n_cols);
  3985. for(uword i = 0; i < locs.n_cols; ++i)
  3986. {
  3987. const uword* locs_i = locs.colptr(i);
  3988. abslocs[i] = locs_i[1] * n_rows + locs_i[0];
  3989. }
  3990. uvec sorted_indices = sort_index(abslocs); // Ascending sort.
  3991. // work out the number of unique elments
  3992. uword n_unique = 1; // first element is unique
  3993. for(uword i=1; i < sorted_indices.n_elem; ++i)
  3994. {
  3995. const uword* locs_i = locs.colptr( sorted_indices[i ] );
  3996. const uword* locs_im1 = locs.colptr( sorted_indices[i-1] );
  3997. if( (locs_i[1] != locs_im1[1]) || (locs_i[0] != locs_im1[0]) ) { ++n_unique; }
  3998. }
  3999. // resize to correct number of elements
  4000. mem_resize(n_unique);
  4001. // Now we add the elements in this sorted order.
  4002. uword count = 0;
  4003. // first element
  4004. {
  4005. const uword i = 0;
  4006. const uword* locs_i = locs.colptr( sorted_indices[i] );
  4007. arma_debug_check( ( (locs_i[0] >= n_rows) || (locs_i[1] >= n_cols) ), "SpMat::SpMat(): invalid row or column index" );
  4008. access::rw(values[count]) = vals[ sorted_indices[i] ];
  4009. access::rw(row_indices[count]) = locs_i[0];
  4010. access::rw(col_ptrs[ locs_i[1] + 1 ])++;
  4011. }
  4012. for(uword i=1; i < sorted_indices.n_elem; ++i)
  4013. {
  4014. const uword* locs_i = locs.colptr( sorted_indices[i ] );
  4015. const uword* locs_im1 = locs.colptr( sorted_indices[i-1] );
  4016. arma_debug_check( ( (locs_i[0] >= n_rows) || (locs_i[1] >= n_cols) ), "SpMat::SpMat(): invalid row or column index" );
  4017. if( (locs_i[1] == locs_im1[1]) && (locs_i[0] == locs_im1[0]) )
  4018. {
  4019. access::rw(values[count]) += vals[ sorted_indices[i] ];
  4020. }
  4021. else
  4022. {
  4023. count++;
  4024. access::rw(values[count]) = vals[ sorted_indices[i] ];
  4025. access::rw(row_indices[count]) = locs_i[0];
  4026. access::rw(col_ptrs[ locs_i[1] + 1 ])++;
  4027. }
  4028. }
  4029. }
  4030. }
  4031. if( (sort_locations == false) || (actually_sorted == true) )
  4032. {
  4033. // work out the number of unique elments
  4034. uword n_unique = 1; // first element is unique
  4035. for(uword i=1; i < locs.n_cols; ++i)
  4036. {
  4037. const uword* locs_i = locs.colptr(i );
  4038. const uword* locs_im1 = locs.colptr(i-1);
  4039. if( (locs_i[1] != locs_im1[1]) || (locs_i[0] != locs_im1[0]) ) { ++n_unique; }
  4040. }
  4041. // resize to correct number of elements
  4042. mem_resize(n_unique);
  4043. // Now set the values and row indices correctly.
  4044. // Increment the column pointers in each column (so they are column "counts").
  4045. uword count = 0;
  4046. // first element
  4047. {
  4048. const uword i = 0;
  4049. const uword* locs_i = locs.colptr(i);
  4050. arma_debug_check( ( (locs_i[0] >= n_rows) || (locs_i[1] >= n_cols) ), "SpMat::SpMat(): invalid row or column index" );
  4051. access::rw(values[count]) = vals[i];
  4052. access::rw(row_indices[count]) = locs_i[0];
  4053. access::rw(col_ptrs[ locs_i[1] + 1 ])++;
  4054. }
  4055. for(uword i=1; i < locs.n_cols; ++i)
  4056. {
  4057. const uword* locs_i = locs.colptr(i );
  4058. const uword* locs_im1 = locs.colptr(i-1);
  4059. arma_debug_check( ( (locs_i[0] >= n_rows) || (locs_i[1] >= n_cols) ), "SpMat::SpMat(): invalid row or column index" );
  4060. arma_debug_check
  4061. (
  4062. ( (locs_i[1] < locs_im1[1]) || (locs_i[1] == locs_im1[1] && locs_i[0] < locs_im1[0]) ),
  4063. "SpMat::SpMat(): out of order points; either pass sort_locations = true, or sort points in column-major ordering"
  4064. );
  4065. if( (locs_i[1] == locs_im1[1]) && (locs_i[0] == locs_im1[0]) )
  4066. {
  4067. access::rw(values[count]) += vals[i];
  4068. }
  4069. else
  4070. {
  4071. count++;
  4072. access::rw(values[count]) = vals[i];
  4073. access::rw(row_indices[count]) = locs_i[0];
  4074. access::rw(col_ptrs[ locs_i[1] + 1 ])++;
  4075. }
  4076. }
  4077. }
  4078. // Now fix the column pointers.
  4079. for(uword i = 0; i < n_cols; ++i)
  4080. {
  4081. access::rw(col_ptrs[i + 1]) += col_ptrs[i];
  4082. }
  4083. }
  4084. //! constructor used by SpRow and SpCol classes
  4085. template<typename eT>
  4086. inline
  4087. SpMat<eT>::SpMat(const arma_vec_indicator&, const uword in_vec_state)
  4088. : n_rows(0)
  4089. , n_cols(0)
  4090. , n_elem(0)
  4091. , n_nonzero(0)
  4092. , vec_state(in_vec_state)
  4093. , values(NULL)
  4094. , row_indices(NULL)
  4095. , col_ptrs(NULL)
  4096. {
  4097. arma_extra_debug_sigprint_this(this);
  4098. const uword in_n_rows = (in_vec_state == 2) ? 1 : 0;
  4099. const uword in_n_cols = (in_vec_state == 1) ? 1 : 0;
  4100. init_cold(in_n_rows, in_n_cols);
  4101. }
  4102. //! constructor used by SpRow and SpCol classes
  4103. template<typename eT>
  4104. inline
  4105. SpMat<eT>::SpMat(const arma_vec_indicator&, const uword in_n_rows, const uword in_n_cols, const uword in_vec_state)
  4106. : n_rows(0)
  4107. , n_cols(0)
  4108. , n_elem(0)
  4109. , n_nonzero(0)
  4110. , vec_state(in_vec_state)
  4111. , values(NULL)
  4112. , row_indices(NULL)
  4113. , col_ptrs(NULL)
  4114. {
  4115. arma_extra_debug_sigprint_this(this);
  4116. init_cold(in_n_rows, in_n_cols);
  4117. }
  4118. template<typename eT>
  4119. inline
  4120. void
  4121. SpMat<eT>::mem_resize(const uword new_n_nonzero)
  4122. {
  4123. arma_extra_debug_sigprint();
  4124. invalidate_cache(); // placed here, as mem_resize() is used during matrix modification
  4125. if(n_nonzero == new_n_nonzero) { return; }
  4126. eT* new_values = memory::acquire<eT> (new_n_nonzero + 1);
  4127. uword* new_row_indices = memory::acquire<uword>(new_n_nonzero + 1);
  4128. if( (n_nonzero > 0 ) && (new_n_nonzero > 0) )
  4129. {
  4130. // Copy old elements.
  4131. uword copy_len = (std::min)(n_nonzero, new_n_nonzero);
  4132. arrayops::copy(new_values, values, copy_len);
  4133. arrayops::copy(new_row_indices, row_indices, copy_len);
  4134. }
  4135. if(values) { memory::release(access::rw(values)); }
  4136. if(row_indices) { memory::release(access::rw(row_indices)); }
  4137. access::rw(values) = new_values;
  4138. access::rw(row_indices) = new_row_indices;
  4139. // Set the "fake end" of the matrix by setting the last value and row index to 0.
  4140. // This helps the iterators work correctly.
  4141. access::rw( values[new_n_nonzero]) = 0;
  4142. access::rw(row_indices[new_n_nonzero]) = 0;
  4143. access::rw(n_nonzero) = new_n_nonzero;
  4144. }
  4145. template<typename eT>
  4146. inline
  4147. void
  4148. SpMat<eT>::sync() const
  4149. {
  4150. arma_extra_debug_sigprint();
  4151. sync_csc();
  4152. }
  4153. template<typename eT>
  4154. inline
  4155. void
  4156. SpMat<eT>::remove_zeros()
  4157. {
  4158. arma_extra_debug_sigprint();
  4159. sync_csc();
  4160. invalidate_cache(); // placed here, as remove_zeros() is used during matrix modification
  4161. const uword old_n_nonzero = n_nonzero;
  4162. uword new_n_nonzero = 0;
  4163. const eT* old_values = values;
  4164. for(uword i=0; i < old_n_nonzero; ++i)
  4165. {
  4166. new_n_nonzero += (old_values[i] != eT(0)) ? uword(1) : uword(0);
  4167. }
  4168. if(new_n_nonzero != old_n_nonzero)
  4169. {
  4170. if(new_n_nonzero == 0) { init(n_rows, n_cols); return; }
  4171. SpMat<eT> tmp(arma_reserve_indicator(), n_rows, n_cols, new_n_nonzero);
  4172. uword new_index = 0;
  4173. const_iterator it = begin();
  4174. const_iterator it_end = end();
  4175. for(; it != it_end; ++it)
  4176. {
  4177. const eT val = eT(*it);
  4178. if(val != eT(0))
  4179. {
  4180. access::rw(tmp.values[new_index]) = val;
  4181. access::rw(tmp.row_indices[new_index]) = it.row();
  4182. access::rw(tmp.col_ptrs[it.col() + 1])++;
  4183. ++new_index;
  4184. }
  4185. }
  4186. for(uword i=0; i < n_cols; ++i)
  4187. {
  4188. access::rw(tmp.col_ptrs[i + 1]) += tmp.col_ptrs[i];
  4189. }
  4190. steal_mem(tmp);
  4191. }
  4192. }
  4193. // Steal memory from another matrix.
  4194. template<typename eT>
  4195. inline
  4196. void
  4197. SpMat<eT>::steal_mem(SpMat<eT>& x)
  4198. {
  4199. arma_extra_debug_sigprint();
  4200. if(this == &x) { return; }
  4201. bool layout_ok = false;
  4202. if((*this).vec_state == x.vec_state)
  4203. {
  4204. layout_ok = true;
  4205. }
  4206. else
  4207. {
  4208. if( ((*this).vec_state == 1) && (x.n_cols == 1) ) { layout_ok = true; }
  4209. if( ((*this).vec_state == 2) && (x.n_rows == 1) ) { layout_ok = true; }
  4210. }
  4211. if(layout_ok)
  4212. {
  4213. x.sync_csc();
  4214. steal_mem_simple(x);
  4215. x.invalidate_cache();
  4216. invalidate_cache();
  4217. }
  4218. else
  4219. {
  4220. (*this).operator=(x);
  4221. }
  4222. }
  4223. template<typename eT>
  4224. inline
  4225. void
  4226. SpMat<eT>::steal_mem_simple(SpMat<eT>& x)
  4227. {
  4228. arma_extra_debug_sigprint();
  4229. if(this == &x) { return; }
  4230. if(values ) { memory::release(access::rw(values)); }
  4231. if(row_indices) { memory::release(access::rw(row_indices)); }
  4232. if(col_ptrs ) { memory::release(access::rw(col_ptrs)); }
  4233. access::rw(n_rows) = x.n_rows;
  4234. access::rw(n_cols) = x.n_cols;
  4235. access::rw(n_elem) = x.n_elem;
  4236. access::rw(n_nonzero) = x.n_nonzero;
  4237. access::rw(values) = x.values;
  4238. access::rw(row_indices) = x.row_indices;
  4239. access::rw(col_ptrs) = x.col_ptrs;
  4240. // Set other matrix to empty.
  4241. access::rw(x.n_rows) = 0;
  4242. access::rw(x.n_cols) = 0;
  4243. access::rw(x.n_elem) = 0;
  4244. access::rw(x.n_nonzero) = 0;
  4245. access::rw(x.values) = NULL;
  4246. access::rw(x.row_indices) = NULL;
  4247. access::rw(x.col_ptrs) = NULL;
  4248. }
  4249. template<typename eT>
  4250. template<typename T1, typename Functor>
  4251. arma_hot
  4252. inline
  4253. void
  4254. SpMat<eT>::init_xform(const SpBase<eT,T1>& A, const Functor& func)
  4255. {
  4256. arma_extra_debug_sigprint();
  4257. // if possible, avoid doing a copy and instead apply func to the generated elements
  4258. if(SpProxy<T1>::Q_is_generated)
  4259. {
  4260. (*this) = A.get_ref();
  4261. const uword nnz = n_nonzero;
  4262. eT* t_values = access::rwp(values);
  4263. bool has_zero = false;
  4264. for(uword i=0; i < nnz; ++i)
  4265. {
  4266. eT& t_values_i = t_values[i];
  4267. t_values_i = func(t_values_i);
  4268. if(t_values_i == eT(0)) { has_zero = true; }
  4269. }
  4270. if(has_zero) { remove_zeros(); }
  4271. }
  4272. else
  4273. {
  4274. init_xform_mt(A.get_ref(), func);
  4275. }
  4276. }
  4277. template<typename eT>
  4278. template<typename eT2, typename T1, typename Functor>
  4279. arma_hot
  4280. inline
  4281. void
  4282. SpMat<eT>::init_xform_mt(const SpBase<eT2,T1>& A, const Functor& func)
  4283. {
  4284. arma_extra_debug_sigprint();
  4285. const SpProxy<T1> P(A.get_ref());
  4286. if( (P.is_alias(*this) == true) || (is_SpMat<typename SpProxy<T1>::stored_type>::value == true) )
  4287. {
  4288. // NOTE: unwrap_spmat will convert a submatrix to a matrix, which in effect takes care of aliasing with submatrices;
  4289. // NOTE: however, when more delayed ops are implemented, more elaborate handling of aliasing will be necessary
  4290. const unwrap_spmat<typename SpProxy<T1>::stored_type> tmp(P.Q);
  4291. const SpMat<eT2>& x = tmp.M;
  4292. if(void_ptr(this) != void_ptr(&x))
  4293. {
  4294. init(x.n_rows, x.n_cols, x.n_nonzero);
  4295. arrayops::copy(access::rwp(row_indices), x.row_indices, x.n_nonzero + 1);
  4296. arrayops::copy(access::rwp(col_ptrs), x.col_ptrs, x.n_cols + 1);
  4297. }
  4298. // initialise the elements array with a transformed version of the elements from x
  4299. const uword nnz = n_nonzero;
  4300. const eT2* x_values = x.values;
  4301. eT* t_values = access::rwp(values);
  4302. bool has_zero = false;
  4303. for(uword i=0; i < nnz; ++i)
  4304. {
  4305. eT& t_values_i = t_values[i];
  4306. t_values_i = func(x_values[i]); // NOTE: func() must produce a value of type eT (ie. act as a convertor between eT2 and eT)
  4307. if(t_values_i == eT(0)) { has_zero = true; }
  4308. }
  4309. if(has_zero) { remove_zeros(); }
  4310. }
  4311. else
  4312. {
  4313. init(P.get_n_rows(), P.get_n_cols(), P.get_n_nonzero());
  4314. typename SpProxy<T1>::const_iterator_type it = P.begin();
  4315. typename SpProxy<T1>::const_iterator_type it_end = P.end();
  4316. bool has_zero = false;
  4317. while(it != it_end)
  4318. {
  4319. const eT val = func(*it); // NOTE: func() must produce a value of type eT (ie. act as a convertor between eT2 and eT)
  4320. if(val == eT(0)) { has_zero = true; }
  4321. access::rw(row_indices[it.pos()]) = it.row();
  4322. access::rw(values[it.pos()]) = val;
  4323. ++access::rw(col_ptrs[it.col() + 1]);
  4324. ++it;
  4325. }
  4326. // Now sum column pointers.
  4327. for(uword c = 1; c <= n_cols; ++c)
  4328. {
  4329. access::rw(col_ptrs[c]) += col_ptrs[c - 1];
  4330. }
  4331. if(has_zero) { remove_zeros(); }
  4332. }
  4333. }
  4334. template<typename eT>
  4335. arma_inline
  4336. bool
  4337. SpMat<eT>::is_alias(const SpMat<eT>& X) const
  4338. {
  4339. return (&X == this);
  4340. }
  4341. template<typename eT>
  4342. inline
  4343. typename SpMat<eT>::iterator
  4344. SpMat<eT>::begin()
  4345. {
  4346. arma_extra_debug_sigprint();
  4347. sync_csc();
  4348. return iterator(*this);
  4349. }
  4350. template<typename eT>
  4351. inline
  4352. typename SpMat<eT>::const_iterator
  4353. SpMat<eT>::begin() const
  4354. {
  4355. arma_extra_debug_sigprint();
  4356. sync_csc();
  4357. return const_iterator(*this);
  4358. }
  4359. template<typename eT>
  4360. inline
  4361. typename SpMat<eT>::const_iterator
  4362. SpMat<eT>::cbegin() const
  4363. {
  4364. arma_extra_debug_sigprint();
  4365. sync_csc();
  4366. return const_iterator(*this);
  4367. }
  4368. template<typename eT>
  4369. inline
  4370. typename SpMat<eT>::iterator
  4371. SpMat<eT>::end()
  4372. {
  4373. sync_csc();
  4374. return iterator(*this, 0, n_cols, n_nonzero);
  4375. }
  4376. template<typename eT>
  4377. inline
  4378. typename SpMat<eT>::const_iterator
  4379. SpMat<eT>::end() const
  4380. {
  4381. sync_csc();
  4382. return const_iterator(*this, 0, n_cols, n_nonzero);
  4383. }
  4384. template<typename eT>
  4385. inline
  4386. typename SpMat<eT>::const_iterator
  4387. SpMat<eT>::cend() const
  4388. {
  4389. sync_csc();
  4390. return const_iterator(*this, 0, n_cols, n_nonzero);
  4391. }
  4392. template<typename eT>
  4393. inline
  4394. typename SpMat<eT>::col_iterator
  4395. SpMat<eT>::begin_col(const uword col_num)
  4396. {
  4397. sync_csc();
  4398. return col_iterator(*this, 0, col_num);
  4399. }
  4400. template<typename eT>
  4401. inline
  4402. typename SpMat<eT>::const_col_iterator
  4403. SpMat<eT>::begin_col(const uword col_num) const
  4404. {
  4405. sync_csc();
  4406. return const_col_iterator(*this, 0, col_num);
  4407. }
  4408. template<typename eT>
  4409. inline
  4410. typename SpMat<eT>::col_iterator
  4411. SpMat<eT>::begin_col_no_sync(const uword col_num)
  4412. {
  4413. return col_iterator(*this, 0, col_num);
  4414. }
  4415. template<typename eT>
  4416. inline
  4417. typename SpMat<eT>::const_col_iterator
  4418. SpMat<eT>::begin_col_no_sync(const uword col_num) const
  4419. {
  4420. return const_col_iterator(*this, 0, col_num);
  4421. }
  4422. template<typename eT>
  4423. inline
  4424. typename SpMat<eT>::col_iterator
  4425. SpMat<eT>::end_col(const uword col_num)
  4426. {
  4427. sync_csc();
  4428. return col_iterator(*this, 0, col_num + 1);
  4429. }
  4430. template<typename eT>
  4431. inline
  4432. typename SpMat<eT>::const_col_iterator
  4433. SpMat<eT>::end_col(const uword col_num) const
  4434. {
  4435. sync_csc();
  4436. return const_col_iterator(*this, 0, col_num + 1);
  4437. }
  4438. template<typename eT>
  4439. inline
  4440. typename SpMat<eT>::col_iterator
  4441. SpMat<eT>::end_col_no_sync(const uword col_num)
  4442. {
  4443. return col_iterator(*this, 0, col_num + 1);
  4444. }
  4445. template<typename eT>
  4446. inline
  4447. typename SpMat<eT>::const_col_iterator
  4448. SpMat<eT>::end_col_no_sync(const uword col_num) const
  4449. {
  4450. return const_col_iterator(*this, 0, col_num + 1);
  4451. }
  4452. template<typename eT>
  4453. inline
  4454. typename SpMat<eT>::row_iterator
  4455. SpMat<eT>::begin_row(const uword row_num)
  4456. {
  4457. sync_csc();
  4458. return row_iterator(*this, row_num, 0);
  4459. }
  4460. template<typename eT>
  4461. inline
  4462. typename SpMat<eT>::const_row_iterator
  4463. SpMat<eT>::begin_row(const uword row_num) const
  4464. {
  4465. sync_csc();
  4466. return const_row_iterator(*this, row_num, 0);
  4467. }
  4468. template<typename eT>
  4469. inline
  4470. typename SpMat<eT>::row_iterator
  4471. SpMat<eT>::end_row()
  4472. {
  4473. sync_csc();
  4474. return row_iterator(*this, n_nonzero);
  4475. }
  4476. template<typename eT>
  4477. inline
  4478. typename SpMat<eT>::const_row_iterator
  4479. SpMat<eT>::end_row() const
  4480. {
  4481. sync_csc();
  4482. return const_row_iterator(*this, n_nonzero);
  4483. }
  4484. template<typename eT>
  4485. inline
  4486. typename SpMat<eT>::row_iterator
  4487. SpMat<eT>::end_row(const uword row_num)
  4488. {
  4489. sync_csc();
  4490. return row_iterator(*this, row_num + 1, 0);
  4491. }
  4492. template<typename eT>
  4493. inline
  4494. typename SpMat<eT>::const_row_iterator
  4495. SpMat<eT>::end_row(const uword row_num) const
  4496. {
  4497. sync_csc();
  4498. return const_row_iterator(*this, row_num + 1, 0);
  4499. }
  4500. template<typename eT>
  4501. inline
  4502. typename SpMat<eT>::row_col_iterator
  4503. SpMat<eT>::begin_row_col()
  4504. {
  4505. sync_csc();
  4506. return begin();
  4507. }
  4508. template<typename eT>
  4509. inline
  4510. typename SpMat<eT>::const_row_col_iterator
  4511. SpMat<eT>::begin_row_col() const
  4512. {
  4513. sync_csc();
  4514. return begin();
  4515. }
  4516. template<typename eT>
  4517. inline typename SpMat<eT>::row_col_iterator
  4518. SpMat<eT>::end_row_col()
  4519. {
  4520. sync_csc();
  4521. return end();
  4522. }
  4523. template<typename eT>
  4524. inline
  4525. typename SpMat<eT>::const_row_col_iterator
  4526. SpMat<eT>::end_row_col() const
  4527. {
  4528. sync_csc();
  4529. return end();
  4530. }
  4531. template<typename eT>
  4532. inline
  4533. void
  4534. SpMat<eT>::clear()
  4535. {
  4536. (*this).reset();
  4537. }
  4538. template<typename eT>
  4539. inline
  4540. bool
  4541. SpMat<eT>::empty() const
  4542. {
  4543. return (n_elem == 0);
  4544. }
  4545. template<typename eT>
  4546. inline
  4547. uword
  4548. SpMat<eT>::size() const
  4549. {
  4550. return n_elem;
  4551. }
  4552. template<typename eT>
  4553. arma_inline
  4554. arma_warn_unused
  4555. SpMat_MapMat_val<eT>
  4556. SpMat<eT>::front()
  4557. {
  4558. arma_debug_check( (n_elem == 0), "SpMat::front(): matrix is empty" );
  4559. return SpMat_MapMat_val<eT>((*this), cache, 0, 0);
  4560. }
  4561. template<typename eT>
  4562. arma_inline
  4563. arma_warn_unused
  4564. eT
  4565. SpMat<eT>::front() const
  4566. {
  4567. arma_debug_check( (n_elem == 0), "SpMat::front(): matrix is empty" );
  4568. return get_value(0,0);
  4569. }
  4570. template<typename eT>
  4571. arma_inline
  4572. arma_warn_unused
  4573. SpMat_MapMat_val<eT>
  4574. SpMat<eT>::back()
  4575. {
  4576. arma_debug_check( (n_elem == 0), "SpMat::back(): matrix is empty" );
  4577. return SpMat_MapMat_val<eT>((*this), cache, n_rows-1, n_cols-1);
  4578. }
  4579. template<typename eT>
  4580. arma_inline
  4581. arma_warn_unused
  4582. eT
  4583. SpMat<eT>::back() const
  4584. {
  4585. arma_debug_check( (n_elem == 0), "SpMat::back(): matrix is empty" );
  4586. return get_value(n_rows-1, n_cols-1);
  4587. }
  4588. template<typename eT>
  4589. inline
  4590. arma_hot
  4591. arma_warn_unused
  4592. eT
  4593. SpMat<eT>::get_value(const uword i) const
  4594. {
  4595. const MapMat<eT>& const_cache = cache; // declare as const for clarity of intent
  4596. // get the element from the cache if it has more recent data than CSC
  4597. return (sync_state == 1) ? const_cache.operator[](i) : get_value_csc(i);
  4598. }
  4599. template<typename eT>
  4600. inline
  4601. arma_hot
  4602. arma_warn_unused
  4603. eT
  4604. SpMat<eT>::get_value(const uword in_row, const uword in_col) const
  4605. {
  4606. const MapMat<eT>& const_cache = cache; // declare as const for clarity of intent
  4607. // get the element from the cache if it has more recent data than CSC
  4608. return (sync_state == 1) ? const_cache.at(in_row, in_col) : get_value_csc(in_row, in_col);
  4609. }
  4610. template<typename eT>
  4611. inline
  4612. arma_hot
  4613. arma_warn_unused
  4614. eT
  4615. SpMat<eT>::get_value_csc(const uword i) const
  4616. {
  4617. // First convert to the actual location.
  4618. uword lcol = i / n_rows; // Integer division.
  4619. uword lrow = i % n_rows;
  4620. return get_value_csc(lrow, lcol);
  4621. }
  4622. template<typename eT>
  4623. inline
  4624. arma_hot
  4625. arma_warn_unused
  4626. const eT*
  4627. SpMat<eT>::find_value_csc(const uword in_row, const uword in_col) const
  4628. {
  4629. const uword col_offset = col_ptrs[in_col ];
  4630. const uword next_col_offset = col_ptrs[in_col + 1];
  4631. const uword* start_ptr = &row_indices[ col_offset];
  4632. const uword* end_ptr = &row_indices[next_col_offset];
  4633. const uword* pos_ptr = std::lower_bound(start_ptr, end_ptr, in_row); // binary search
  4634. if( (pos_ptr != end_ptr) && ((*pos_ptr) == in_row) )
  4635. {
  4636. const uword offset = uword(pos_ptr - start_ptr);
  4637. const uword index = offset + col_offset;
  4638. return &(values[index]);
  4639. }
  4640. return NULL;
  4641. }
  4642. template<typename eT>
  4643. inline
  4644. arma_hot
  4645. arma_warn_unused
  4646. eT
  4647. SpMat<eT>::get_value_csc(const uword in_row, const uword in_col) const
  4648. {
  4649. const eT* val_ptr = find_value_csc(in_row, in_col);
  4650. return (val_ptr != NULL) ? eT(*val_ptr) : eT(0);
  4651. }
  4652. template<typename eT>
  4653. inline
  4654. arma_hot
  4655. arma_warn_unused
  4656. bool
  4657. SpMat<eT>::try_set_value_csc(const uword in_row, const uword in_col, const eT in_val)
  4658. {
  4659. const eT* val_ptr = find_value_csc(in_row, in_col);
  4660. // element not found, ie. it's zero; fail if trying to set it to non-zero value
  4661. if(val_ptr == NULL) { return (in_val == eT(0)); }
  4662. // fail if trying to erase an existing element
  4663. if(in_val == eT(0)) { return false; }
  4664. access::rw(*val_ptr) = in_val;
  4665. invalidate_cache();
  4666. return true;
  4667. }
  4668. template<typename eT>
  4669. inline
  4670. arma_hot
  4671. arma_warn_unused
  4672. bool
  4673. SpMat<eT>::try_add_value_csc(const uword in_row, const uword in_col, const eT in_val)
  4674. {
  4675. const eT* val_ptr = find_value_csc(in_row, in_col);
  4676. // element not found, ie. it's zero; fail if trying to add a non-zero value
  4677. if(val_ptr == NULL) { return (in_val == eT(0)); }
  4678. const eT new_val = eT(*val_ptr) + in_val;
  4679. // fail if trying to erase an existing element
  4680. if(new_val == eT(0)) { return false; }
  4681. access::rw(*val_ptr) = new_val;
  4682. invalidate_cache();
  4683. return true;
  4684. }
  4685. template<typename eT>
  4686. inline
  4687. arma_hot
  4688. arma_warn_unused
  4689. bool
  4690. SpMat<eT>::try_sub_value_csc(const uword in_row, const uword in_col, const eT in_val)
  4691. {
  4692. const eT* val_ptr = find_value_csc(in_row, in_col);
  4693. // element not found, ie. it's zero; fail if trying to subtract a non-zero value
  4694. if(val_ptr == NULL) { return (in_val == eT(0)); }
  4695. const eT new_val = eT(*val_ptr) - in_val;
  4696. // fail if trying to erase an existing element
  4697. if(new_val == eT(0)) { return false; }
  4698. access::rw(*val_ptr) = new_val;
  4699. invalidate_cache();
  4700. return true;
  4701. }
  4702. template<typename eT>
  4703. inline
  4704. arma_hot
  4705. arma_warn_unused
  4706. bool
  4707. SpMat<eT>::try_mul_value_csc(const uword in_row, const uword in_col, const eT in_val)
  4708. {
  4709. const eT* val_ptr = find_value_csc(in_row, in_col);
  4710. // element not found, ie. it's zero; succeed if given value is finite; zero multiplied by anything is zero, except for nan and inf
  4711. if(val_ptr == NULL) { return arma_isfinite(in_val); }
  4712. const eT new_val = eT(*val_ptr) * in_val;
  4713. // fail if trying to erase an existing element
  4714. if(new_val == eT(0)) { return false; }
  4715. access::rw(*val_ptr) = new_val;
  4716. invalidate_cache();
  4717. return true;
  4718. }
  4719. template<typename eT>
  4720. inline
  4721. arma_hot
  4722. arma_warn_unused
  4723. bool
  4724. SpMat<eT>::try_div_value_csc(const uword in_row, const uword in_col, const eT in_val)
  4725. {
  4726. const eT* val_ptr = find_value_csc(in_row, in_col);
  4727. // element not found, ie. it's zero; succeed if given value is not zero and not nan; zero divided by anything is zero, except for zero and nan
  4728. if(val_ptr == NULL) { return ((in_val != eT(0)) && (arma_isnan(in_val) == false)); }
  4729. const eT new_val = eT(*val_ptr) / in_val;
  4730. // fail if trying to erase an existing element
  4731. if(new_val == eT(0)) { return false; }
  4732. access::rw(*val_ptr) = new_val;
  4733. invalidate_cache();
  4734. return true;
  4735. }
  4736. /**
  4737. * Insert an element at the given position, and return a reference to it.
  4738. * The element will be set to 0, unless otherwise specified.
  4739. * If the element already exists, its value will be overwritten.
  4740. */
  4741. template<typename eT>
  4742. inline
  4743. arma_warn_unused
  4744. eT&
  4745. SpMat<eT>::insert_element(const uword in_row, const uword in_col, const eT val)
  4746. {
  4747. arma_extra_debug_sigprint();
  4748. sync_csc();
  4749. invalidate_cache();
  4750. // We will assume the new element does not exist and begin the search for
  4751. // where to insert it. If we find that it already exists, we will then
  4752. // overwrite it.
  4753. uword colptr = col_ptrs[in_col ];
  4754. uword next_colptr = col_ptrs[in_col + 1];
  4755. uword pos = colptr; // The position in the matrix of this value.
  4756. if(colptr != next_colptr)
  4757. {
  4758. // There are other elements in this column, so we must find where this
  4759. // element will fit as compared to those.
  4760. while(pos < next_colptr && in_row > row_indices[pos])
  4761. {
  4762. pos++;
  4763. }
  4764. // We aren't inserting into the last position, so it is still possible
  4765. // that the element may exist.
  4766. if(pos != next_colptr && row_indices[pos] == in_row)
  4767. {
  4768. // It already exists. Then, just overwrite it.
  4769. access::rw(values[pos]) = val;
  4770. return access::rw(values[pos]);
  4771. }
  4772. }
  4773. //
  4774. // Element doesn't exist, so we have to insert it
  4775. //
  4776. // We have to update the rest of the column pointers.
  4777. for(uword i = in_col + 1; i < n_cols + 1; i++)
  4778. {
  4779. access::rw(col_ptrs[i])++; // We are only inserting one new element.
  4780. }
  4781. const uword old_n_nonzero = n_nonzero;
  4782. access::rw(n_nonzero)++; // Add to count of nonzero elements.
  4783. // Allocate larger memory.
  4784. eT* new_values = memory::acquire<eT> (n_nonzero + 1);
  4785. uword* new_row_indices = memory::acquire<uword>(n_nonzero + 1);
  4786. // Copy things over, before the new element.
  4787. if(pos > 0)
  4788. {
  4789. arrayops::copy(new_values, values, pos);
  4790. arrayops::copy(new_row_indices, row_indices, pos);
  4791. }
  4792. // Insert the new element.
  4793. new_values[pos] = val;
  4794. new_row_indices[pos] = in_row;
  4795. // Copy the rest of things over (including the extra element at the end).
  4796. arrayops::copy(new_values + pos + 1, values + pos, (old_n_nonzero - pos) + 1);
  4797. arrayops::copy(new_row_indices + pos + 1, row_indices + pos, (old_n_nonzero - pos) + 1);
  4798. // Assign new pointers.
  4799. if(values) { memory::release(access::rw(values)); }
  4800. if(row_indices) { memory::release(access::rw(row_indices)); }
  4801. access::rw(values) = new_values;
  4802. access::rw(row_indices) = new_row_indices;
  4803. return access::rw(values[pos]);
  4804. }
  4805. /**
  4806. * Delete an element at the given position.
  4807. */
  4808. template<typename eT>
  4809. inline
  4810. void
  4811. SpMat<eT>::delete_element(const uword in_row, const uword in_col)
  4812. {
  4813. arma_extra_debug_sigprint();
  4814. sync_csc();
  4815. invalidate_cache();
  4816. // We assume the element exists (although... it may not) and look for its
  4817. // exact position. If it doesn't exist... well, we don't need to do anything.
  4818. uword colptr = col_ptrs[in_col];
  4819. uword next_colptr = col_ptrs[in_col + 1];
  4820. if(colptr != next_colptr)
  4821. {
  4822. // There's at least one element in this column.
  4823. // Let's see if we are one of them.
  4824. for(uword pos = colptr; pos < next_colptr; pos++)
  4825. {
  4826. if(in_row == row_indices[pos])
  4827. {
  4828. --access::rw(n_nonzero); // Remove one from the count of nonzero elements.
  4829. // Found it. Now remove it.
  4830. // Make new arrays.
  4831. eT* new_values = memory::acquire<eT> (n_nonzero + 1);
  4832. uword* new_row_indices = memory::acquire<uword>(n_nonzero + 1);
  4833. if(pos > 0)
  4834. {
  4835. arrayops::copy(new_values, values, pos);
  4836. arrayops::copy(new_row_indices, row_indices, pos);
  4837. }
  4838. arrayops::copy(new_values + pos, values + pos + 1, (n_nonzero - pos) + 1);
  4839. arrayops::copy(new_row_indices + pos, row_indices + pos + 1, (n_nonzero - pos) + 1);
  4840. if(values) { memory::release(access::rw(values)); }
  4841. if(row_indices) { memory::release(access::rw(row_indices)); }
  4842. access::rw(values) = new_values;
  4843. access::rw(row_indices) = new_row_indices;
  4844. // And lastly, update all the column pointers (decrement by one).
  4845. for(uword i = in_col + 1; i < n_cols + 1; i++)
  4846. {
  4847. --access::rw(col_ptrs[i]); // We only removed one element.
  4848. }
  4849. return; // There is nothing left to do.
  4850. }
  4851. }
  4852. }
  4853. return; // The element does not exist, so there's nothing for us to do.
  4854. }
  4855. template<typename eT>
  4856. arma_inline
  4857. void
  4858. SpMat<eT>::invalidate_cache() const
  4859. {
  4860. arma_extra_debug_sigprint();
  4861. if(sync_state == 0) { return; }
  4862. cache.reset();
  4863. sync_state = 0;
  4864. }
  4865. template<typename eT>
  4866. arma_inline
  4867. void
  4868. SpMat<eT>::invalidate_csc() const
  4869. {
  4870. arma_extra_debug_sigprint();
  4871. sync_state = 1;
  4872. }
  4873. template<typename eT>
  4874. inline
  4875. void
  4876. SpMat<eT>::sync_cache() const
  4877. {
  4878. arma_extra_debug_sigprint();
  4879. // using approach adapted from http://preshing.com/20130930/double-checked-locking-is-fixed-in-cpp11/
  4880. //
  4881. // OpenMP mode:
  4882. // sync_state uses atomic read/write, which has an implied flush;
  4883. // flush is also implicitly executed at the entrance and the exit of critical section;
  4884. // data races are prevented by the 'critical' directive
  4885. //
  4886. // C++11 mode:
  4887. // underlying type for sync_state is std::atomic<int>;
  4888. // reading and writing to sync_state uses std::memory_order_seq_cst which has an implied fence;
  4889. // data races are prevented via the mutex
  4890. #if defined(ARMA_USE_OPENMP)
  4891. {
  4892. if(sync_state == 0)
  4893. {
  4894. #pragma omp critical (arma_SpMat_cache)
  4895. {
  4896. sync_cache_simple();
  4897. }
  4898. }
  4899. }
  4900. #elif (defined(ARMA_USE_CXX11) && !defined(ARMA_DONT_USE_CXX11_MUTEX))
  4901. {
  4902. if(sync_state == 0)
  4903. {
  4904. cache_mutex.lock();
  4905. sync_cache_simple();
  4906. cache_mutex.unlock();
  4907. }
  4908. }
  4909. #else
  4910. {
  4911. sync_cache_simple();
  4912. }
  4913. #endif
  4914. }
  4915. template<typename eT>
  4916. inline
  4917. void
  4918. SpMat<eT>::sync_cache_simple() const
  4919. {
  4920. arma_extra_debug_sigprint();
  4921. if(sync_state == 0)
  4922. {
  4923. cache = (*this);
  4924. sync_state = 2;
  4925. }
  4926. }
  4927. template<typename eT>
  4928. inline
  4929. void
  4930. SpMat<eT>::sync_csc() const
  4931. {
  4932. arma_extra_debug_sigprint();
  4933. #if defined(ARMA_USE_OPENMP)
  4934. if(sync_state == 1)
  4935. {
  4936. #pragma omp critical (arma_SpMat_cache)
  4937. {
  4938. sync_csc_simple();
  4939. }
  4940. }
  4941. #elif (defined(ARMA_USE_CXX11) && !defined(ARMA_DONT_USE_CXX11_MUTEX))
  4942. if(sync_state == 1)
  4943. {
  4944. cache_mutex.lock();
  4945. sync_csc_simple();
  4946. cache_mutex.unlock();
  4947. }
  4948. #else
  4949. {
  4950. sync_csc_simple();
  4951. }
  4952. #endif
  4953. }
  4954. template<typename eT>
  4955. inline
  4956. void
  4957. SpMat<eT>::sync_csc_simple() const
  4958. {
  4959. arma_extra_debug_sigprint();
  4960. // method:
  4961. // 1. construct temporary matrix to prevent the cache from getting zapped
  4962. // 2. steal memory from the temporary matrix
  4963. // sync_state is only set to 1 by non-const element access operators,
  4964. // so the shenanigans with const_cast are to satisfy the compiler
  4965. // see also the note in sync_cache() above
  4966. if(sync_state == 1)
  4967. {
  4968. SpMat<eT>& x = const_cast< SpMat<eT>& >(*this);
  4969. SpMat<eT> tmp(cache);
  4970. x.steal_mem_simple(tmp);
  4971. sync_state = 2;
  4972. }
  4973. }
  4974. //
  4975. // SpMat_aux
  4976. template<typename eT, typename T1>
  4977. inline
  4978. void
  4979. SpMat_aux::set_real(SpMat<eT>& out, const SpBase<eT,T1>& X)
  4980. {
  4981. arma_extra_debug_sigprint();
  4982. const unwrap_spmat<T1> tmp(X.get_ref());
  4983. const SpMat<eT>& A = tmp.M;
  4984. arma_debug_assert_same_size( out, A, "SpMat::set_real()" );
  4985. out = A;
  4986. }
  4987. template<typename eT, typename T1>
  4988. inline
  4989. void
  4990. SpMat_aux::set_imag(SpMat<eT>&, const SpBase<eT,T1>&)
  4991. {
  4992. arma_extra_debug_sigprint();
  4993. }
  4994. template<typename T, typename T1>
  4995. inline
  4996. void
  4997. SpMat_aux::set_real(SpMat< std::complex<T> >& out, const SpBase<T,T1>& X)
  4998. {
  4999. arma_extra_debug_sigprint();
  5000. typedef typename std::complex<T> eT;
  5001. const unwrap_spmat<T1> U(X.get_ref());
  5002. const SpMat<T>& Y = U.M;
  5003. arma_debug_assert_same_size(out, Y, "SpMat::set_real()");
  5004. SpMat<eT> tmp(Y,arma::imag(out)); // arma:: prefix required due to bugs in GCC 4.4 - 4.6
  5005. out.steal_mem(tmp);
  5006. }
  5007. template<typename T, typename T1>
  5008. inline
  5009. void
  5010. SpMat_aux::set_imag(SpMat< std::complex<T> >& out, const SpBase<T,T1>& X)
  5011. {
  5012. arma_extra_debug_sigprint();
  5013. typedef typename std::complex<T> eT;
  5014. const unwrap_spmat<T1> U(X.get_ref());
  5015. const SpMat<T>& Y = U.M;
  5016. arma_debug_assert_same_size(out, Y, "SpMat::set_imag()");
  5017. SpMat<eT> tmp(arma::real(out),Y); // arma:: prefix required due to bugs in GCC 4.4 - 4.6
  5018. out.steal_mem(tmp);
  5019. }
  5020. #ifdef ARMA_EXTRA_SPMAT_MEAT
  5021. #include ARMA_INCFILE_WRAP(ARMA_EXTRA_SPMAT_MEAT)
  5022. #endif
  5023. //! @}