fn_svds.hpp 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup fn_svds
  16. //! @{
  17. template<typename T1>
  18. inline
  19. bool
  20. svds_helper
  21. (
  22. Mat<typename T1::elem_type>& U,
  23. Col<typename T1::pod_type >& S,
  24. Mat<typename T1::elem_type>& V,
  25. const SpBase<typename T1::elem_type,T1>& X,
  26. const uword k,
  27. const typename T1::pod_type tol,
  28. const bool calc_UV,
  29. const typename arma_real_only<typename T1::elem_type>::result* junk = 0
  30. )
  31. {
  32. arma_extra_debug_sigprint();
  33. arma_ignore(junk);
  34. typedef typename T1::elem_type eT;
  35. typedef typename T1::pod_type T;
  36. arma_debug_check
  37. (
  38. ( ((void*)(&U) == (void*)(&S)) || (&U == &V) || ((void*)(&S) == (void*)(&V)) ),
  39. "svds(): two or more output objects are the same object"
  40. );
  41. arma_debug_check( (tol < T(0)), "svds(): tol must be >= 0" );
  42. const unwrap_spmat<T1> tmp(X.get_ref());
  43. const SpMat<eT>& A = tmp.M;
  44. const uword kk = (std::min)( (std::min)(A.n_rows, A.n_cols), k );
  45. const T A_max = (A.n_nonzero > 0) ? T(max(abs(Col<eT>(const_cast<eT*>(A.values), A.n_nonzero, false)))) : T(0);
  46. if(A_max == T(0))
  47. {
  48. // TODO: use reset instead ?
  49. S.zeros(kk);
  50. if(calc_UV)
  51. {
  52. U.eye(A.n_rows, kk);
  53. V.eye(A.n_cols, kk);
  54. }
  55. }
  56. else
  57. {
  58. SpMat<eT> C( (A.n_rows + A.n_cols), (A.n_rows + A.n_cols) );
  59. SpMat<eT> B = A / A_max;
  60. SpMat<eT> Bt = B.t();
  61. C(0, A.n_rows, arma::size(B) ) = B;
  62. C(A.n_rows, 0, arma::size(Bt)) = Bt;
  63. Bt.reset();
  64. B.reset();
  65. Col<eT> eigval;
  66. Mat<eT> eigvec;
  67. const bool status = sp_auxlib::eigs_sym(eigval, eigvec, C, kk, "la", (tol / Datum<T>::sqrt2));
  68. if(status == false)
  69. {
  70. U.soft_reset();
  71. S.soft_reset();
  72. V.soft_reset();
  73. return false;
  74. }
  75. const T A_norm = max(eigval);
  76. const T tol2 = tol / Datum<T>::sqrt2 * A_norm;
  77. uvec indices = find(eigval > tol2);
  78. if(indices.n_elem > kk)
  79. {
  80. indices = indices.subvec(0,kk-1);
  81. }
  82. else
  83. if(indices.n_elem < kk)
  84. {
  85. const uvec indices2 = find(abs(eigval) <= tol2);
  86. const uword N_extra = (std::min)( indices2.n_elem, (kk - indices.n_elem) );
  87. if(N_extra > 0) { indices = join_cols(indices, indices2.subvec(0,N_extra-1)); }
  88. }
  89. const uvec sorted_indices = sort_index(eigval, "descend");
  90. S = eigval.elem(sorted_indices); S *= A_max;
  91. if(calc_UV)
  92. {
  93. uvec U_row_indices(A.n_rows); for(uword i=0; i < A.n_rows; ++i) { U_row_indices[i] = i; }
  94. uvec V_row_indices(A.n_cols); for(uword i=0; i < A.n_cols; ++i) { V_row_indices[i] = i + A.n_rows; }
  95. U = Datum<T>::sqrt2 * eigvec(U_row_indices, sorted_indices);
  96. V = Datum<T>::sqrt2 * eigvec(V_row_indices, sorted_indices);
  97. }
  98. }
  99. if(S.n_elem < k) { arma_debug_warn("svds(): found fewer singular values than specified"); }
  100. return true;
  101. }
  102. template<typename T1>
  103. inline
  104. bool
  105. svds_helper
  106. (
  107. Mat<typename T1::elem_type>& U,
  108. Col<typename T1::pod_type >& S,
  109. Mat<typename T1::elem_type>& V,
  110. const SpBase<typename T1::elem_type,T1>& X,
  111. const uword k,
  112. const typename T1::pod_type tol,
  113. const bool calc_UV,
  114. const typename arma_cx_only<typename T1::elem_type>::result* junk = 0
  115. )
  116. {
  117. arma_extra_debug_sigprint();
  118. arma_ignore(junk);
  119. typedef typename T1::elem_type eT;
  120. typedef typename T1::pod_type T;
  121. if(arma_config::arpack == false)
  122. {
  123. arma_stop_logic_error("svds(): use of ARPACK must be enabled for decomposition of complex matrices");
  124. return false;
  125. }
  126. arma_debug_check
  127. (
  128. ( ((void*)(&U) == (void*)(&S)) || (&U == &V) || ((void*)(&S) == (void*)(&V)) ),
  129. "svds(): two or more output objects are the same object"
  130. );
  131. arma_debug_check( (tol < T(0)), "svds(): tol must be >= 0" );
  132. const unwrap_spmat<T1> tmp(X.get_ref());
  133. const SpMat<eT>& A = tmp.M;
  134. const uword kk = (std::min)( (std::min)(A.n_rows, A.n_cols), k );
  135. const T A_max = (A.n_nonzero > 0) ? T(max(abs(Col<eT>(const_cast<eT*>(A.values), A.n_nonzero, false)))) : T(0);
  136. if(A_max == T(0))
  137. {
  138. // TODO: use reset instead ?
  139. S.zeros(kk);
  140. if(calc_UV)
  141. {
  142. U.eye(A.n_rows, kk);
  143. V.eye(A.n_cols, kk);
  144. }
  145. }
  146. else
  147. {
  148. SpMat<eT> C( (A.n_rows + A.n_cols), (A.n_rows + A.n_cols) );
  149. SpMat<eT> B = A / A_max;
  150. SpMat<eT> Bt = B.t();
  151. C(0, A.n_rows, arma::size(B) ) = B;
  152. C(A.n_rows, 0, arma::size(Bt)) = Bt;
  153. Bt.reset();
  154. B.reset();
  155. Col<eT> eigval_tmp;
  156. Mat<eT> eigvec;
  157. const bool status = sp_auxlib::eigs_gen(eigval_tmp, eigvec, C, kk, "lr", (tol / Datum<T>::sqrt2));
  158. if(status == false)
  159. {
  160. U.soft_reset();
  161. S.soft_reset();
  162. V.soft_reset();
  163. return false;
  164. }
  165. const Col<T> eigval = real(eigval_tmp);
  166. const T A_norm = max(eigval);
  167. const T tol2 = tol / Datum<T>::sqrt2 * A_norm;
  168. uvec indices = find(eigval > tol2);
  169. if(indices.n_elem > kk)
  170. {
  171. indices = indices.subvec(0,kk-1);
  172. }
  173. else
  174. if(indices.n_elem < kk)
  175. {
  176. const uvec indices2 = find(abs(eigval) <= tol2);
  177. const uword N_extra = (std::min)( indices2.n_elem, (kk - indices.n_elem) );
  178. if(N_extra > 0) { indices = join_cols(indices, indices2.subvec(0,N_extra-1)); }
  179. }
  180. const uvec sorted_indices = sort_index(eigval, "descend");
  181. S = eigval.elem(sorted_indices); S *= A_max;
  182. if(calc_UV)
  183. {
  184. uvec U_row_indices(A.n_rows); for(uword i=0; i < A.n_rows; ++i) { U_row_indices[i] = i; }
  185. uvec V_row_indices(A.n_cols); for(uword i=0; i < A.n_cols; ++i) { V_row_indices[i] = i + A.n_rows; }
  186. U = Datum<T>::sqrt2 * eigvec(U_row_indices, sorted_indices);
  187. V = Datum<T>::sqrt2 * eigvec(V_row_indices, sorted_indices);
  188. }
  189. }
  190. if(S.n_elem < k) { arma_debug_warn("svds(): found fewer singular values than specified"); }
  191. return true;
  192. }
  193. //! find the k largest singular values and corresponding singular vectors of sparse matrix X
  194. template<typename T1>
  195. inline
  196. bool
  197. svds
  198. (
  199. Mat<typename T1::elem_type>& U,
  200. Col<typename T1::pod_type >& S,
  201. Mat<typename T1::elem_type>& V,
  202. const SpBase<typename T1::elem_type,T1>& X,
  203. const uword k,
  204. const typename T1::pod_type tol = 0.0,
  205. const typename arma_real_or_cx_only<typename T1::elem_type>::result* junk = 0
  206. )
  207. {
  208. arma_extra_debug_sigprint();
  209. arma_ignore(junk);
  210. const bool status = svds_helper(U, S, V, X.get_ref(), k, tol, true);
  211. if(status == false) { arma_debug_warn("svds(): decomposition failed"); }
  212. return status;
  213. }
  214. //! find the k largest singular values of sparse matrix X
  215. template<typename T1>
  216. inline
  217. bool
  218. svds
  219. (
  220. Col<typename T1::pod_type >& S,
  221. const SpBase<typename T1::elem_type,T1>& X,
  222. const uword k,
  223. const typename T1::pod_type tol = 0.0,
  224. const typename arma_real_or_cx_only<typename T1::elem_type>::result* junk = 0
  225. )
  226. {
  227. arma_extra_debug_sigprint();
  228. arma_ignore(junk);
  229. Mat<typename T1::elem_type> U;
  230. Mat<typename T1::elem_type> V;
  231. const bool status = svds_helper(U, S, V, X.get_ref(), k, tol, false);
  232. if(status == false) { arma_debug_warn("svds(): decomposition failed"); }
  233. return status;
  234. }
  235. //! find the k largest singular values of sparse matrix X
  236. template<typename T1>
  237. arma_warn_unused
  238. inline
  239. Col<typename T1::pod_type>
  240. svds
  241. (
  242. const SpBase<typename T1::elem_type,T1>& X,
  243. const uword k,
  244. const typename T1::pod_type tol = 0.0,
  245. const typename arma_real_or_cx_only<typename T1::elem_type>::result* junk = 0
  246. )
  247. {
  248. arma_extra_debug_sigprint();
  249. arma_ignore(junk);
  250. Col<typename T1::pod_type> S;
  251. Mat<typename T1::elem_type> U;
  252. Mat<typename T1::elem_type> V;
  253. const bool status = svds_helper(U, S, V, X.get_ref(), k, tol, false);
  254. if(status == false) { arma_stop_runtime_error("svds(): decomposition failed"); }
  255. return S;
  256. }
  257. //! @}