// newarp_GenEigsSolver_meat.hpp
// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
// Copyright 2008-2016 National ICT Australia (NICTA)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ------------------------------------------------------------------------
  15. namespace newarp
  16. {
  17. template<typename eT, int SelectionRule, typename OpType>
  18. inline
  19. void
  20. GenEigsSolver<eT, SelectionRule, OpType>::factorise_from(uword from_k, uword to_m, const Col<eT>& fk)
  21. {
  22. arma_extra_debug_sigprint();
  23. if(to_m <= from_k) { return; }
  24. fac_f = fk;
  25. Col<eT> w(dim_n);
  26. eT beta = norm(fac_f);
  27. // Keep the upperleft k x k submatrix of H and set other elements to 0
  28. fac_H.tail_cols(ncv - from_k).zeros();
  29. fac_H.submat(span(from_k, ncv - 1), span(0, from_k - 1)).zeros();
  30. for(uword i = from_k; i <= to_m - 1; i++)
  31. {
  32. bool restart = false;
  33. // If beta = 0, then the next V is not full rank
  34. // We need to generate a new residual vector that is orthogonal
  35. // to the current V, which we call a restart
  36. if(beta < eps)
  37. {
  38. // Generate new random vector for fac_f
  39. blas_int idist = 2;
  40. blas_int iseed[4] = {1, 3, 5, 7};
  41. iseed[0] = (i + 100) % 4095;
  42. blas_int n = dim_n;
  43. lapack::larnv(&idist, &iseed[0], &n, fac_f.memptr());
  44. // f <- f - V * V' * f, so that f is orthogonal to V
  45. Mat<eT> Vs(fac_V.memptr(), dim_n, i, false); // First i columns
  46. Col<eT> Vf = Vs.t() * fac_f;
  47. fac_f -= Vs * Vf;
  48. // beta <- ||f||
  49. beta = norm(fac_f);
  50. restart = true;
  51. }
  52. // v <- f / ||f||
  53. fac_V.col(i) = fac_f / beta; // The (i+1)-th column
  54. // Note that H[i+1, i] equals to the unrestarted beta
  55. if(restart) { fac_H(i, i - 1) = 0.0; } else { fac_H(i, i - 1) = beta; }
  56. // w <- A * v, v = fac_V.col(i)
  57. op.perform_op(fac_V.colptr(i), w.memptr());
  58. nmatop++;
  59. // First i+1 columns of V
  60. Mat<eT> Vs(fac_V.memptr(), dim_n, i + 1, false);
  61. // h = fac_H(0:i, i)
  62. Col<eT> h(fac_H.colptr(i), i + 1, false);
  63. // h <- V' * w
  64. h = Vs.t() * w;
  65. // f <- w - V * h
  66. fac_f = w - Vs * h;
  67. beta = norm(fac_f);
  68. if(beta > 0.717 * norm(h)) { continue; }
  69. // f/||f|| is going to be the next column of V, so we need to test
  70. // whether V' * (f/||f||) ~= 0
  71. Col<eT> Vf = Vs.t() * fac_f;
  72. // If not, iteratively correct the residual
  73. uword count = 0;
  74. while(count < 5 && abs(Vf).max() > approx0 * beta)
  75. {
  76. // f <- f - V * Vf
  77. fac_f -= Vs * Vf;
  78. // h <- h + Vf
  79. h += Vf;
  80. // beta <- ||f||
  81. beta = norm(fac_f);
  82. Vf = Vs.t() * fac_f;
  83. count++;
  84. }
  85. }
  86. }
  87. template<typename eT, int SelectionRule, typename OpType>
  88. inline
  89. void
  90. GenEigsSolver<eT, SelectionRule, OpType>::restart(uword k)
  91. {
  92. arma_extra_debug_sigprint();
  93. if(k >= ncv) { return; }
  94. DoubleShiftQR<eT> decomp_ds(ncv);
  95. UpperHessenbergQR<eT> decomp;
  96. Mat<eT> Q(ncv, ncv, fill::eye);
  97. for(uword i = k; i < ncv; i++)
  98. {
  99. if(cx_attrib::is_complex(ritz_val(i), eT(0)) && (i < (ncv - 1)) && cx_attrib::is_conj(ritz_val(i), ritz_val(i + 1), eT(0)))
  100. {
  101. // H - mu * I = Q1 * R1
  102. // H <- R1 * Q1 + mu * I = Q1' * H * Q1
  103. // H - conj(mu) * I = Q2 * R2
  104. // H <- R2 * Q2 + conj(mu) * I = Q2' * H * Q2
  105. //
  106. // (H - mu * I) * (H - conj(mu) * I) = Q1 * Q2 * R2 * R1 = Q * R
  107. eT s = 2 * ritz_val(i).real();
  108. eT t = std::norm(ritz_val(i));
  109. decomp_ds.compute(fac_H, s, t);
  110. // Q -> Q * Qi
  111. decomp_ds.apply_YQ(Q);
  112. // H -> Q'HQ
  113. fac_H = decomp_ds.matrix_QtHQ();
  114. i++;
  115. }
  116. else
  117. {
  118. // QR decomposition of H - mu * I, mu is real
  119. fac_H.diag() -= ritz_val(i).real();
  120. decomp.compute(fac_H);
  121. // Q -> Q * Qi
  122. decomp.apply_YQ(Q);
  123. // H -> Q'HQ = RQ + mu * I
  124. fac_H = decomp.matrix_RQ();
  125. fac_H.diag() += ritz_val(i).real();
  126. }
  127. }
  128. // V -> VQ
  129. // Q has some elements being zero
  130. // The first (ncv - k + i) elements of the i-th column of Q are non-zero
  131. Mat<eT> Vs(dim_n, k + 1);
  132. uword nnz;
  133. for(uword i = 0; i < k; i++)
  134. {
  135. nnz = ncv - k + i + 1;
  136. Mat<eT> V(fac_V.memptr(), dim_n, nnz, false);
  137. Col<eT> q(Q.colptr(i), nnz, false);
  138. Col<eT> v(Vs.colptr(i), dim_n, false);
  139. v = V * q;
  140. }
  141. Vs.col(k) = fac_V * Q.col(k);
  142. fac_V.head_cols(k + 1) = Vs;
  143. Col<eT> fk = fac_f * Q(ncv - 1, k - 1) + fac_V.col(k) * fac_H(k, k - 1);
  144. factorise_from(k, ncv, fk);
  145. retrieve_ritzpair();
  146. }
  147. template<typename eT, int SelectionRule, typename OpType>
  148. inline
  149. uword
  150. GenEigsSolver<eT, SelectionRule, OpType>::num_converged(eT tol)
  151. {
  152. arma_extra_debug_sigprint();
  153. // thresh = tol * max(prec, abs(theta)), theta for ritz value
  154. const eT f_norm = arma::norm(fac_f);
  155. for(uword i = 0; i < nev; i++)
  156. {
  157. eT thresh = tol * std::max(approx0, std::abs(ritz_val(i)));
  158. eT resid = std::abs(ritz_est(i)) * f_norm;
  159. ritz_conv[i] = (resid < thresh);
  160. }
  161. return std::count(ritz_conv.begin(), ritz_conv.end(), true);
  162. }
  163. template<typename eT, int SelectionRule, typename OpType>
  164. inline
  165. uword
  166. GenEigsSolver<eT, SelectionRule, OpType>::nev_adjusted(uword nconv)
  167. {
  168. arma_extra_debug_sigprint();
  169. uword nev_new = nev;
  170. for(uword i = nev; i < ncv; i++)
  171. {
  172. if(std::abs(ritz_est(i)) < eps) { nev_new++; }
  173. }
  174. // Adjust nev_new again, according to dnaup2.f line 660~674 in ARPACK
  175. nev_new += std::min(nconv, (ncv - nev_new) / 2);
  176. if(nev_new == 1 && ncv >= 6)
  177. {
  178. nev_new = ncv / 2;
  179. }
  180. else
  181. if(nev_new == 1 && ncv > 3)
  182. {
  183. nev_new = 2;
  184. }
  185. if(nev_new > ncv - 2) { nev_new = ncv - 2; }
  186. // Increase nev by one if ritz_val[nev - 1] and
  187. // ritz_val[nev] are conjugate pairs
  188. if(cx_attrib::is_complex(ritz_val(nev_new - 1), eps) && cx_attrib::is_conj(ritz_val(nev_new - 1), ritz_val(nev_new), eps))
  189. {
  190. nev_new++;
  191. }
  192. return nev_new;
  193. }
  194. template<typename eT, int SelectionRule, typename OpType>
  195. inline
  196. void
  197. GenEigsSolver<eT, SelectionRule, OpType>::retrieve_ritzpair()
  198. {
  199. arma_extra_debug_sigprint();
  200. UpperHessenbergEigen<eT> decomp(fac_H);
  201. Col< std::complex<eT> > evals = decomp.eigenvalues();
  202. Mat< std::complex<eT> > evecs = decomp.eigenvectors();
  203. SortEigenvalue< std::complex<eT>, SelectionRule > sorting(evals.memptr(), evals.n_elem);
  204. std::vector<uword> ind = sorting.index();
  205. // Copy the ritz values and vectors to ritz_val and ritz_vec, respectively
  206. for(uword i = 0; i < ncv; i++)
  207. {
  208. ritz_val(i) = evals(ind[i]);
  209. ritz_est(i) = evecs(ncv - 1, ind[i]);
  210. }
  211. for(uword i = 0; i < nev; i++)
  212. {
  213. ritz_vec.col(i) = evecs.col(ind[i]);
  214. }
  215. }
  216. template<typename eT, int SelectionRule, typename OpType>
  217. inline
  218. void
  219. GenEigsSolver<eT, SelectionRule, OpType>::sort_ritzpair()
  220. {
  221. arma_extra_debug_sigprint();
  222. // SortEigenvalue< std::complex<eT>, EigsSelect::LARGEST_MAGN > sorting(ritz_val.memptr(), nev);
  223. // sort Ritz values according to SelectionRule, to be consistent with ARPACK
  224. SortEigenvalue< std::complex<eT>, SelectionRule > sorting(ritz_val.memptr(), nev);
  225. std::vector<uword> ind = sorting.index();
  226. Col< std::complex<eT> > new_ritz_val(ncv);
  227. Mat< std::complex<eT> > new_ritz_vec(ncv, nev);
  228. std::vector<bool> new_ritz_conv(nev);
  229. for(uword i = 0; i < nev; i++)
  230. {
  231. new_ritz_val(i) = ritz_val(ind[i]);
  232. new_ritz_vec.col(i) = ritz_vec.col(ind[i]);
  233. new_ritz_conv[i] = ritz_conv[ind[i]];
  234. }
  235. ritz_val.swap(new_ritz_val);
  236. ritz_vec.swap(new_ritz_vec);
  237. ritz_conv.swap(new_ritz_conv);
  238. }
  239. template<typename eT, int SelectionRule, typename OpType>
  240. inline
  241. GenEigsSolver<eT, SelectionRule, OpType>::GenEigsSolver(const OpType& op_, uword nev_, uword ncv_)
  242. : op(op_)
  243. , nev(nev_)
  244. , dim_n(op.n_rows)
  245. , ncv(ncv_ > dim_n ? dim_n : ncv_)
  246. , nmatop(0)
  247. , niter(0)
  248. , eps(std::numeric_limits<eT>::epsilon())
  249. , approx0(std::pow(eps, eT(2.0) / 3))
  250. {
  251. arma_extra_debug_sigprint();
  252. arma_debug_check( (nev_ < 1 || nev_ > dim_n - 2), "newarp::GenEigsSolver: nev must satisfy 1 <= nev <= n - 2, n is the size of matrix" );
  253. arma_debug_check( (ncv_ < nev_ + 2 || ncv_ > dim_n), "newarp::GenEigsSolver: ncv must satisfy nev + 2 <= ncv <= n, n is the size of matrix" );
  254. }
  255. template<typename eT, int SelectionRule, typename OpType>
  256. inline
  257. void
  258. GenEigsSolver<eT, SelectionRule, OpType>::init(eT* init_resid)
  259. {
  260. arma_extra_debug_sigprint();
  261. // Reset all matrices/vectors to zero
  262. fac_V.zeros(dim_n, ncv);
  263. fac_H.zeros(ncv, ncv);
  264. fac_f.zeros(dim_n);
  265. ritz_val.zeros(ncv);
  266. ritz_vec.zeros(ncv, nev);
  267. ritz_est.zeros(ncv);
  268. ritz_conv.assign(nev, false);
  269. nmatop = 0;
  270. niter = 0;
  271. Col<eT> r(init_resid, dim_n, false);
  272. // The first column of fac_V
  273. Col<eT> v(fac_V.colptr(0), dim_n, false);
  274. eT rnorm = norm(r);
  275. arma_check( (rnorm < eps), "newarp::GenEigsSolver::init(): initial residual vector cannot be zero" );
  276. v = r / rnorm;
  277. Col<eT> w(dim_n);
  278. op.perform_op(v.memptr(), w.memptr());
  279. nmatop++;
  280. fac_H(0, 0) = dot(v, w);
  281. fac_f = w - v * fac_H(0, 0);
  282. }
  283. template<typename eT, int SelectionRule, typename OpType>
  284. inline
  285. void
  286. GenEigsSolver<eT, SelectionRule, OpType>::init()
  287. {
  288. arma_extra_debug_sigprint();
  289. podarray<eT> init_resid(dim_n);
  290. blas_int idist = 2; // Uniform(-1, 1)
  291. blas_int iseed[4] = {1, 3, 5, 7}; // Fixed random seed
  292. blas_int n = dim_n;
  293. lapack::larnv(&idist, &iseed[0], &n, init_resid.memptr());
  294. init(init_resid.memptr());
  295. }
  296. template<typename eT, int SelectionRule, typename OpType>
  297. inline
  298. uword
  299. GenEigsSolver<eT, SelectionRule, OpType>::compute(uword maxit, eT tol)
  300. {
  301. arma_extra_debug_sigprint();
  302. // The m-step Arnoldi factorisation
  303. factorise_from(1, ncv, fac_f);
  304. retrieve_ritzpair();
  305. // Restarting
  306. uword i, nconv = 0, nev_adj;
  307. for(i = 0; i < maxit; i++)
  308. {
  309. nconv = num_converged(tol);
  310. if(nconv >= nev) { break; }
  311. nev_adj = nev_adjusted(nconv);
  312. restart(nev_adj);
  313. }
  314. // Sorting results
  315. sort_ritzpair();
  316. niter = i + 1;
  317. return std::min(nev, nconv);
  318. }
  319. template<typename eT, int SelectionRule, typename OpType>
  320. inline
  321. Col< std::complex<eT> >
  322. GenEigsSolver<eT, SelectionRule, OpType>::eigenvalues()
  323. {
  324. arma_extra_debug_sigprint();
  325. uword nconv = std::count(ritz_conv.begin(), ritz_conv.end(), true);
  326. Col< std::complex<eT> > res(nconv);
  327. if(nconv > 0)
  328. {
  329. uword j = 0;
  330. for(uword i = 0; i < nev; i++)
  331. {
  332. if(ritz_conv[i])
  333. {
  334. res(j) = ritz_val(i);
  335. j++;
  336. }
  337. }
  338. }
  339. return res;
  340. }
  341. template<typename eT, int SelectionRule, typename OpType>
  342. inline
  343. Mat< std::complex<eT> >
  344. GenEigsSolver<eT, SelectionRule, OpType>::eigenvectors(uword nvec)
  345. {
  346. arma_extra_debug_sigprint();
  347. uword nconv = std::count(ritz_conv.begin(), ritz_conv.end(), true);
  348. nvec = std::min(nvec, nconv);
  349. Mat< std::complex<eT> > res(dim_n, nvec);
  350. if(nvec > 0)
  351. {
  352. Mat< std::complex<eT> > ritz_vec_conv(ncv, nvec);
  353. uword j = 0;
  354. for(uword i = 0; (i < nev) && (j < nvec); i++)
  355. {
  356. if(ritz_conv[i])
  357. {
  358. ritz_vec_conv.col(j) = ritz_vec.col(i);
  359. j++;
  360. }
  361. }
  362. res = fac_V * ritz_vec_conv;
  363. }
  364. return res;
  365. }
  366. } // namespace newarp