newarp_UpperHessenbergQR_meat.hpp 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. namespace newarp
  16. {
  17. template<typename eT>
  18. inline
  19. UpperHessenbergQR<eT>::UpperHessenbergQR()
  20. : n(0)
  21. , computed(false)
  22. {
  23. arma_extra_debug_sigprint();
  24. }
  25. template<typename eT>
  26. inline
  27. UpperHessenbergQR<eT>::UpperHessenbergQR(const Mat<eT>& mat_obj)
  28. : n(mat_obj.n_rows)
  29. , mat_T(n, n)
  30. , rot_cos(n - 1)
  31. , rot_sin(n - 1)
  32. , computed(false)
  33. {
  34. arma_extra_debug_sigprint();
  35. compute(mat_obj);
  36. }
  37. template<typename eT>
  38. void
  39. UpperHessenbergQR<eT>::compute(const Mat<eT>& mat_obj)
  40. {
  41. arma_extra_debug_sigprint();
  42. n = mat_obj.n_rows;
  43. mat_T.set_size(n, n);
  44. rot_cos.set_size(n - 1);
  45. rot_sin.set_size(n - 1);
  46. // Make a copy of mat_obj
  47. mat_T = mat_obj;
  48. eT xi, xj, r, c, s, eps = std::numeric_limits<eT>::epsilon();
  49. eT *ptr;
  50. for(uword i = 0; i < n - 1; i++)
  51. {
  52. // Make sure mat_T is upper Hessenberg
  53. // Zero the elements below mat_T(i + 1, i)
  54. if(i < n - 2) { mat_T(span(i + 2, n - 1), i).zeros(); }
  55. xi = mat_T(i, i); // mat_T(i, i)
  56. xj = mat_T(i + 1, i); // mat_T(i + 1, i)
  57. r = arma_hypot(xi, xj);
  58. if(r <= eps)
  59. {
  60. r = 0;
  61. rot_cos(i) = c = 1;
  62. rot_sin(i) = s = 0;
  63. }
  64. else
  65. {
  66. rot_cos(i) = c = xi / r;
  67. rot_sin(i) = s = -xj / r;
  68. }
  69. // For a complete QR decomposition,
  70. // we first obtain the rotation matrix
  71. // G = [ cos sin]
  72. // [-sin cos]
  73. // and then do T[i:(i + 1), i:(n - 1)] = G' * T[i:(i + 1), i:(n - 1)]
  74. // mat_T.submat(i, i, i + 1, n - 1) = Gt * mat_T.submat(i, i, i + 1, n - 1);
  75. mat_T(i, i) = r; // mat_T(i, i) => r
  76. mat_T(i + 1, i) = 0; // mat_T(i + 1, i) => 0
  77. ptr = &mat_T(i, i + 1); // mat_T(i, k), k = i+1, i+2, ..., n-1
  78. for(uword j = i + 1; j < n; j++, ptr += n)
  79. {
  80. eT tmp = ptr[0];
  81. ptr[0] = c * tmp - s * ptr[1];
  82. ptr[1] = s * tmp + c * ptr[1];
  83. }
  84. }
  85. computed = true;
  86. }
  87. template<typename eT>
  88. Mat<eT>
  89. UpperHessenbergQR<eT>::matrix_RQ()
  90. {
  91. arma_extra_debug_sigprint();
  92. arma_debug_check( (computed == false), "newarp::UpperHessenbergQR::matrix_RQ(): need to call compute() first" );
  93. // Make a copy of the R matrix
  94. Mat<eT> RQ = trimatu(mat_T);
  95. for(uword i = 0; i < n - 1; i++)
  96. {
  97. // RQ[, i:(i + 1)] = RQ[, i:(i + 1)] * Gi
  98. // Gi = [ cos[i] sin[i]]
  99. // [-sin[i] cos[i]]
  100. const eT c = rot_cos(i);
  101. const eT s = rot_sin(i);
  102. eT *Yi, *Yi1;
  103. Yi = RQ.colptr(i);
  104. Yi1 = RQ.colptr(i + 1);
  105. for(uword j = 0; j < i + 2; j++)
  106. {
  107. eT tmp = Yi[j];
  108. Yi[j] = c * tmp - s * Yi1[j];
  109. Yi1[j] = s * tmp + c * Yi1[j];
  110. }
  111. /* Yi = RQ(span(0, i + 1), i);
  112. RQ(span(0, i + 1), i) = (*c) * Yi - (*s) * RQ(span(0, i + 1), i + 1);
  113. RQ(span(0, i + 1), i + 1) = (*s) * Yi + (*c) * RQ(span(0, i + 1), i + 1); */
  114. }
  115. return RQ;
  116. }
  117. template<typename eT>
  118. inline
  119. void
  120. UpperHessenbergQR<eT>::apply_YQ(Mat<eT>& Y)
  121. {
  122. arma_extra_debug_sigprint();
  123. arma_debug_check( (computed == false), "newarp::UpperHessenbergQR::apply_YQ(): need to call compute() first" );
  124. eT *Y_col_i, *Y_col_i1;
  125. uword nrow = Y.n_rows;
  126. for(uword i = 0; i < n - 1; i++)
  127. {
  128. const eT c = rot_cos(i);
  129. const eT s = rot_sin(i);
  130. Y_col_i = Y.colptr(i);
  131. Y_col_i1 = Y.colptr(i + 1);
  132. for(uword j = 0; j < nrow; j++)
  133. {
  134. eT tmp = Y_col_i[j];
  135. Y_col_i[j] = c * tmp - s * Y_col_i1[j];
  136. Y_col_i1[j] = s * tmp + c * Y_col_i1[j];
  137. }
  138. }
  139. }
  140. template<typename eT>
  141. inline
  142. TridiagQR<eT>::TridiagQR()
  143. : UpperHessenbergQR<eT>()
  144. {
  145. arma_extra_debug_sigprint();
  146. }
  147. template<typename eT>
  148. inline
  149. TridiagQR<eT>::TridiagQR(const Mat<eT>& mat_obj)
  150. : UpperHessenbergQR<eT>()
  151. {
  152. arma_extra_debug_sigprint();
  153. this->compute(mat_obj);
  154. }
  155. template<typename eT>
  156. inline
  157. void
  158. TridiagQR<eT>::compute(const Mat<eT>& mat_obj)
  159. {
  160. arma_extra_debug_sigprint();
  161. this->n = mat_obj.n_rows;
  162. this->mat_T.set_size(this->n, this->n);
  163. this->rot_cos.set_size(this->n - 1);
  164. this->rot_sin.set_size(this->n - 1);
  165. this->mat_T.zeros();
  166. this->mat_T.diag() = mat_obj.diag();
  167. this->mat_T.diag(1) = mat_obj.diag(-1);
  168. this->mat_T.diag(-1) = mat_obj.diag(-1);
  169. eT xi, xj, r, c, s, tmp, eps = std::numeric_limits<eT>::epsilon();
  170. eT *ptr; // A number of pointers to avoid repeated address calculation
  171. for(uword i = 0; i < this->n - 1; i++)
  172. {
  173. xi = this->mat_T(i, i); // mat_T(i, i)
  174. xj = this->mat_T(i + 1, i); // mat_T(i + 1, i)
  175. r = arma_hypot(xi, xj);
  176. if(r <= eps)
  177. {
  178. r = 0;
  179. this->rot_cos(i) = c = 1;
  180. this->rot_sin(i) = s = 0;
  181. }
  182. else
  183. {
  184. this->rot_cos(i) = c = xi / r;
  185. this->rot_sin(i) = s = -xj / r;
  186. }
  187. // For a complete QR decomposition,
  188. // we first obtain the rotation matrix
  189. // G = [ cos sin]
  190. // [-sin cos]
  191. // and then do T[i:(i + 1), i:(i + 2)] = G' * T[i:(i + 1), i:(i + 2)]
  192. // Update T[i, i] and T[i + 1, i]
  193. // The updated value of T[i, i] is known to be r
  194. // The updated value of T[i + 1, i] is known to be 0
  195. this->mat_T(i, i) = r;
  196. this->mat_T(i + 1, i) = 0;
  197. // Update T[i, i + 1] and T[i + 1, i + 1]
  198. // ptr[0] == T[i, i + 1]
  199. // ptr[1] == T[i + 1, i + 1]
  200. ptr = &(this->mat_T(i, i + 1));
  201. tmp = *ptr;
  202. ptr[0] = c * tmp - s * ptr[1];
  203. ptr[1] = s * tmp + c * ptr[1];
  204. if(i < this->n - 2)
  205. {
  206. // Update T[i, i + 2] and T[i + 1, i + 2]
  207. // ptr[0] == T[i, i + 2] == 0
  208. // ptr[1] == T[i + 1, i + 2]
  209. ptr = &(this->mat_T(i, i + 2));
  210. ptr[0] = -s * ptr[1];
  211. ptr[1] *= c;
  212. }
  213. }
  214. this->computed = true;
  215. }
  216. template<typename eT>
  217. Mat<eT>
  218. TridiagQR<eT>::matrix_RQ()
  219. {
  220. arma_extra_debug_sigprint();
  221. arma_debug_check( (this->computed == false), "newarp::TridiagQR::matrix_RQ(): need to call compute() first" );
  222. // Make a copy of the R matrix
  223. Mat<eT> RQ(this->n, this->n, fill::zeros);
  224. RQ.diag() = this->mat_T.diag();
  225. RQ.diag(1) = this->mat_T.diag(1);
  226. // [m11 m12] will point to RQ[i:(i+1), i:(i+1)]
  227. // [m21 m22]
  228. eT *m11 = RQ.memptr(), *m12, *m21, *m22, tmp;
  229. for(uword i = 0; i < this->n - 1; i++)
  230. {
  231. const eT c = this->rot_cos(i);
  232. const eT s = this->rot_sin(i);
  233. m21 = m11 + 1;
  234. m12 = m11 + this->n;
  235. m22 = m12 + 1;
  236. tmp = *m21;
  237. // Update diagonal and the below-subdiagonal
  238. *m11 = c * (*m11) - s * (*m12);
  239. *m21 = c * tmp - s * (*m22);
  240. *m22 = s * tmp + c * (*m22);
  241. // Move m11 to RQ[i+1, i+1]
  242. m11 = m22;
  243. }
  244. // Copy the below-subdiagonal to above-subdiagonal
  245. RQ.diag(1) = RQ.diag(-1);
  246. return RQ;
  247. }
  248. } // namespace newarp