band_helper.hpp 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup band_helper
  16. //! @{
  17. namespace band_helper
  18. {
  19. template<typename eT>
  20. inline
  21. bool
  22. is_band(uword& out_KL, uword& out_KU, const Mat<eT>& A, const uword N_min)
  23. {
  24. arma_extra_debug_sigprint();
  25. // NOTE: assuming that A has a square size
  26. // NOTE: assuming that N_min is >= 4
  27. const uword N = A.n_rows;
  28. if(N < N_min) { return false; }
  29. // first, quickly check bottom-left and top-right corners
  30. const eT eT_zero = eT(0);
  31. const eT* A_col0 = A.memptr();
  32. const eT* A_col1 = A_col0 + N;
  33. if( (A_col0[N-2] != eT_zero) || (A_col0[N-1] != eT_zero) || (A_col1[N-2] != eT_zero) || (A_col1[N-1] != eT_zero) ) { return false; }
  34. const eT* A_colNm2 = A.colptr(N-2);
  35. const eT* A_colNm1 = A_colNm2 + N;
  36. if( (A_colNm2[0] != eT_zero) || (A_colNm2[1] != eT_zero) || (A_colNm1[0] != eT_zero) || (A_colNm1[1] != eT_zero) ) { return false; }
  37. // if we reached this point, go through the entire matrix to work out number of subdiagonals and superdiagonals
  38. const uword n_nonzero_threshold = (N*N)/4; // empirically determined
  39. uword KL = 0; // number of subdiagonals
  40. uword KU = 0; // number of superdiagonals
  41. const eT* A_colptr = A.memptr();
  42. for(uword col=0; col < N; ++col)
  43. {
  44. uword first_nonzero_row = col;
  45. uword last_nonzero_row = col;
  46. for(uword row=0; row < col; ++row)
  47. {
  48. if( A_colptr[row] != eT_zero ) { first_nonzero_row = row; break; }
  49. }
  50. for(uword row=(col+1); row < N; ++row)
  51. {
  52. last_nonzero_row = (A_colptr[row] != eT_zero) ? row : last_nonzero_row;
  53. }
  54. const uword L_count = last_nonzero_row - col;
  55. const uword U_count = col - first_nonzero_row;
  56. if( (L_count > KL) || (U_count > KU) )
  57. {
  58. KL = (std::max)(KL, L_count);
  59. KU = (std::max)(KU, U_count);
  60. const uword n_nonzero = N*(KL+KU+1) - (KL*(KL+1) + KU*(KU+1))/2;
  61. // return as soon as we know that it's not worth analysing the matrix any further
  62. if(n_nonzero > n_nonzero_threshold) { return false; }
  63. }
  64. A_colptr += N;
  65. }
  66. out_KL = KL;
  67. out_KU = KU;
  68. return true;
  69. }
  70. template<typename eT>
  71. inline
  72. bool
  73. is_band_lower(uword& out_KD, const Mat<eT>& A, const uword N_min)
  74. {
  75. arma_extra_debug_sigprint();
  76. // NOTE: assuming that A has a square size
  77. // NOTE: assuming that N_min is >= 4
  78. const uword N = A.n_rows;
  79. if(N < N_min) { return false; }
  80. // first, quickly check bottom-left corner
  81. const eT eT_zero = eT(0);
  82. const eT* A_col0 = A.memptr();
  83. const eT* A_col1 = A_col0 + N;
  84. if( (A_col0[N-2] != eT_zero) || (A_col0[N-1] != eT_zero) || (A_col1[N-2] != eT_zero) || (A_col1[N-1] != eT_zero) ) { return false; }
  85. // if we reached this point, go through the bottom triangle to work out number of subdiagonals
  86. const uword n_nonzero_threshold = ( N*N - (N*(N-1))/2 ) / 4; // empirically determined
  87. uword KL = 0; // number of subdiagonals
  88. const eT* A_colptr = A.memptr();
  89. for(uword col=0; col < N; ++col)
  90. {
  91. uword last_nonzero_row = col;
  92. for(uword row=(col+1); row < N; ++row)
  93. {
  94. last_nonzero_row = (A_colptr[row] != eT_zero) ? row : last_nonzero_row;
  95. }
  96. const uword L_count = last_nonzero_row - col;
  97. if(L_count > KL)
  98. {
  99. KL = L_count;
  100. const uword n_nonzero = N*(KL+1) - (KL*(KL+1))/2;
  101. // return as soon as we know that it's not worth analysing the matrix any further
  102. if(n_nonzero > n_nonzero_threshold) { return false; }
  103. }
  104. A_colptr += N;
  105. }
  106. out_KD = KL;
  107. return true;
  108. }
  109. template<typename eT>
  110. inline
  111. bool
  112. is_band_upper(uword& out_KD, const Mat<eT>& A, const uword N_min)
  113. {
  114. arma_extra_debug_sigprint();
  115. // NOTE: assuming that A has a square size
  116. // NOTE: assuming that N_min is >= 4
  117. const uword N = A.n_rows;
  118. if(N < N_min) { return false; }
  119. // first, quickly check top-right corner
  120. const eT eT_zero = eT(0);
  121. const eT* A_colNm2 = A.colptr(N-2);
  122. const eT* A_colNm1 = A_colNm2 + N;
  123. if( (A_colNm2[0] != eT_zero) || (A_colNm2[1] != eT_zero) || (A_colNm1[0] != eT_zero) || (A_colNm1[1] != eT_zero) ) { return false; }
  124. // if we reached this point, go through the entire matrix to work out number of superdiagonals
  125. const uword n_nonzero_threshold = ( N*N - (N*(N-1))/2 ) / 4; // empirically determined
  126. uword KU = 0; // number of superdiagonals
  127. const eT* A_colptr = A.memptr();
  128. for(uword col=0; col < N; ++col)
  129. {
  130. uword first_nonzero_row = col;
  131. for(uword row=0; row < col; ++row)
  132. {
  133. if( A_colptr[row] != eT_zero ) { first_nonzero_row = row; break; }
  134. }
  135. const uword U_count = col - first_nonzero_row;
  136. if(U_count > KU)
  137. {
  138. KU = U_count;
  139. const uword n_nonzero = N*(KU+1) - (KU*(KU+1))/2;
  140. // return as soon as we know that it's not worth analysing the matrix any further
  141. if(n_nonzero > n_nonzero_threshold) { return false; }
  142. }
  143. A_colptr += N;
  144. }
  145. out_KD = KU;
  146. return true;
  147. }
  148. template<typename eT>
  149. inline
  150. void
  151. compress(Mat<eT>& AB, const Mat<eT>& A, const uword KL, const uword KU, const bool use_offset)
  152. {
  153. arma_extra_debug_sigprint();
  154. // NOTE: assuming that A has a square size
  155. // band matrix storage format
  156. // http://www.netlib.org/lapack/lug/node124.html
  157. // for ?gbsv, matrix AB size: 2*KL+KU+1 x N; band representation of A stored in rows KL+1 to 2*KL+KU+1 (note: fortran counts from 1)
  158. // for ?gbsvx, matrix AB size: KL+KU+1 x N; band representaiton of A stored in rows 1 to KL+KU+1 (note: fortran counts from 1)
  159. //
  160. // the +1 in the above formulas is to take into account the main diagonal
  161. const uword AB_n_rows = (use_offset) ? uword(2*KL + KU + 1) : uword(KL + KU + 1);
  162. const uword N = A.n_rows;
  163. AB.set_size(AB_n_rows, N);
  164. if(A.is_empty()) { AB.zeros(); return; }
  165. if(AB_n_rows == uword(1))
  166. {
  167. eT* AB_mem = AB.memptr();
  168. for(uword i=0; i<N; ++i) { AB_mem[i] = A.at(i,i); }
  169. }
  170. else
  171. {
  172. AB.zeros(); // paranoia
  173. for(uword j=0; j < N; ++j)
  174. {
  175. const uword A_row_start = (j > KU) ? uword(j - KU) : uword(0);
  176. const uword A_row_endp1 = (std::min)(N, j+KL+1);
  177. const uword length = A_row_endp1 - A_row_start;
  178. const uword AB_row_start = (KU > j) ? (KU - j) : uword(0);
  179. const eT* A_colptr = A.colptr(j) + A_row_start;
  180. eT* AB_colptr = AB.colptr(j) + AB_row_start + ( (use_offset) ? KL : uword(0) );
  181. arrayops::copy( AB_colptr, A_colptr, length );
  182. }
  183. }
  184. }
  185. template<typename eT>
  186. inline
  187. void
  188. uncompress(Mat<eT>& A, const Mat<eT>& AB, const uword KL, const uword KU, const bool use_offset)
  189. {
  190. arma_extra_debug_sigprint();
  191. const uword AB_n_rows = AB.n_rows;
  192. const uword N = AB.n_cols;
  193. arma_debug_check( (AB_n_rows != ((use_offset) ? uword(2*KL + KU + 1) : uword(KL + KU + 1))), "band_helper::uncompress(): detected inconsistency" );
  194. A.zeros(N,N); // assuming there is no aliasing between A and AB
  195. if(AB_n_rows == uword(1))
  196. {
  197. const eT* AB_mem = AB.memptr();
  198. for(uword i=0; i<N; ++i) { A.at(i,i) = AB_mem[i]; }
  199. }
  200. else
  201. {
  202. for(uword j=0; j < N; ++j)
  203. {
  204. const uword A_row_start = (j > KU) ? uword(j - KU) : uword(0);
  205. const uword A_row_endp1 = (std::min)(N, j+KL+1);
  206. const uword length = A_row_endp1 - A_row_start;
  207. const uword AB_row_start = (KU > j) ? (KU - j) : uword(0);
  208. const eT* AB_colptr = AB.colptr(j) + AB_row_start + ( (use_offset) ? KL : uword(0) );
  209. eT* A_colptr = A.colptr(j) + A_row_start;
  210. arrayops::copy( A_colptr, AB_colptr, length );
  211. }
  212. }
  213. }
  214. template<typename eT>
  215. inline
  216. void
  217. extract_tridiag(Mat<eT>& out, const Mat<eT>& A)
  218. {
  219. arma_extra_debug_sigprint();
  220. // NOTE: assuming that A has a square size and is at least 2x2
  221. const uword N = A.n_rows;
  222. out.set_size(N, 3); // assuming there is no aliasing between 'out' and 'A'
  223. if(N < 2) { return; }
  224. eT* DL = out.colptr(0);
  225. eT* DD = out.colptr(1);
  226. eT* DU = out.colptr(2);
  227. DD[0] = A[0];
  228. DL[0] = A[1];
  229. const uword Nm1 = N-1;
  230. const uword Nm2 = N-2;
  231. for(uword i=0; i < Nm2; ++i)
  232. {
  233. const uword ip1 = i+1;
  234. const eT* data = &(A.at(i, ip1));
  235. const eT tmp0 = data[0];
  236. const eT tmp1 = data[1];
  237. const eT tmp2 = data[2];
  238. DL[ip1] = tmp2;
  239. DD[ip1] = tmp1;
  240. DU[i ] = tmp0;
  241. }
  242. const eT* data = &(A.at(Nm2, Nm1));
  243. DL[Nm1] = 0;
  244. DU[Nm2] = data[0];
  245. DU[Nm1] = 0;
  246. DD[Nm1] = data[1];
  247. }
  248. } // end of namespace band_helper
  249. //! @}