op_var_meat.hpp 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup op_var
  16. //! @{
  17. //! \brief
  18. //! For each row or for each column, find the variance.
  19. //! The result is stored in a dense matrix that has either one column or one row.
  20. //! The dimension, for which the variances are found, is set via the var() function.
  21. template<typename T1>
  22. inline
  23. void
  24. op_var::apply(Mat<typename T1::pod_type>& out, const mtOp<typename T1::pod_type, T1, op_var>& in)
  25. {
  26. arma_extra_debug_sigprint();
  27. typedef typename T1::elem_type in_eT;
  28. typedef typename T1::pod_type out_eT;
  29. const unwrap_check_mixed<T1> tmp(in.m, out);
  30. const Mat<in_eT>& X = tmp.M;
  31. const uword norm_type = in.aux_uword_a;
  32. const uword dim = in.aux_uword_b;
  33. arma_debug_check( (norm_type > 1), "var(): parameter 'norm_type' must be 0 or 1" );
  34. arma_debug_check( (dim > 1), "var(): parameter 'dim' must be 0 or 1" );
  35. const uword X_n_rows = X.n_rows;
  36. const uword X_n_cols = X.n_cols;
  37. if(dim == 0)
  38. {
  39. arma_extra_debug_print("op_var::apply(): dim = 0");
  40. out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols);
  41. if(X_n_rows > 0)
  42. {
  43. out_eT* out_mem = out.memptr();
  44. for(uword col=0; col<X_n_cols; ++col)
  45. {
  46. out_mem[col] = op_var::direct_var( X.colptr(col), X_n_rows, norm_type );
  47. }
  48. }
  49. }
  50. else
  51. if(dim == 1)
  52. {
  53. arma_extra_debug_print("op_var::apply(): dim = 1");
  54. out.set_size(X_n_rows, (X_n_cols > 0) ? 1 : 0);
  55. if(X_n_cols > 0)
  56. {
  57. podarray<in_eT> dat(X_n_cols);
  58. in_eT* dat_mem = dat.memptr();
  59. out_eT* out_mem = out.memptr();
  60. for(uword row=0; row<X_n_rows; ++row)
  61. {
  62. dat.copy_row(X, row);
  63. out_mem[row] = op_var::direct_var( dat_mem, X_n_cols, norm_type );
  64. }
  65. }
  66. }
  67. }
  68. template<typename T1>
  69. inline
  70. typename T1::pod_type
  71. op_var::var_vec(const Base<typename T1::elem_type, T1>& X, const uword norm_type)
  72. {
  73. arma_extra_debug_sigprint();
  74. arma_debug_check( (norm_type > 1), "var(): parameter 'norm_type' must be 0 or 1" );
  75. const quasi_unwrap<T1> U(X.get_ref());
  76. return op_var::direct_var(U.M.memptr(), U.M.n_elem, norm_type);
  77. }
  78. template<typename eT>
  79. inline
  80. typename get_pod_type<eT>::result
  81. op_var::var_vec(const subview_col<eT>& X, const uword norm_type)
  82. {
  83. arma_extra_debug_sigprint();
  84. arma_debug_check( (norm_type > 1), "var(): parameter 'norm_type' must be 0 or 1" );
  85. return op_var::direct_var(X.colptr(0), X.n_rows, norm_type);
  86. }
  87. template<typename eT>
  88. inline
  89. typename get_pod_type<eT>::result
  90. op_var::var_vec(const subview_row<eT>& X, const uword norm_type)
  91. {
  92. arma_extra_debug_sigprint();
  93. arma_debug_check( (norm_type > 1), "var(): parameter 'norm_type' must be 0 or 1" );
  94. const Mat<eT>& A = X.m;
  95. const uword start_row = X.aux_row1;
  96. const uword start_col = X.aux_col1;
  97. const uword end_col_p1 = start_col + X.n_cols;
  98. podarray<eT> tmp(X.n_elem);
  99. eT* tmp_mem = tmp.memptr();
  100. for(uword i=0, col=start_col; col < end_col_p1; ++col, ++i)
  101. {
  102. tmp_mem[i] = A.at(start_row, col);
  103. }
  104. return op_var::direct_var(tmp.memptr(), tmp.n_elem, norm_type);
  105. }
  106. //! find the variance of an array
  107. template<typename eT>
  108. inline
  109. eT
  110. op_var::direct_var(const eT* const X, const uword n_elem, const uword norm_type)
  111. {
  112. arma_extra_debug_sigprint();
  113. if(n_elem >= 2)
  114. {
  115. const eT acc1 = op_mean::direct_mean(X, n_elem);
  116. eT acc2 = eT(0);
  117. eT acc3 = eT(0);
  118. uword i,j;
  119. for(i=0, j=1; j<n_elem; i+=2, j+=2)
  120. {
  121. const eT Xi = X[i];
  122. const eT Xj = X[j];
  123. const eT tmpi = acc1 - Xi;
  124. const eT tmpj = acc1 - Xj;
  125. acc2 += tmpi*tmpi + tmpj*tmpj;
  126. acc3 += tmpi + tmpj;
  127. }
  128. if(i < n_elem)
  129. {
  130. const eT Xi = X[i];
  131. const eT tmpi = acc1 - Xi;
  132. acc2 += tmpi*tmpi;
  133. acc3 += tmpi;
  134. }
  135. const eT norm_val = (norm_type == 0) ? eT(n_elem-1) : eT(n_elem);
  136. const eT var_val = (acc2 - acc3*acc3/eT(n_elem)) / norm_val;
  137. return arma_isfinite(var_val) ? var_val : op_var::direct_var_robust(X, n_elem, norm_type);
  138. }
  139. else
  140. {
  141. return eT(0);
  142. }
  143. }
  144. //! find the variance of an array (robust but slow)
  145. template<typename eT>
  146. inline
  147. eT
  148. op_var::direct_var_robust(const eT* const X, const uword n_elem, const uword norm_type)
  149. {
  150. arma_extra_debug_sigprint();
  151. if(n_elem > 1)
  152. {
  153. eT r_mean = X[0];
  154. eT r_var = eT(0);
  155. for(uword i=1; i<n_elem; ++i)
  156. {
  157. const eT tmp = X[i] - r_mean;
  158. const eT i_plus_1 = eT(i+1);
  159. r_var = eT(i-1)/eT(i) * r_var + (tmp*tmp)/i_plus_1;
  160. r_mean = r_mean + tmp/i_plus_1;
  161. }
  162. return (norm_type == 0) ? r_var : (eT(n_elem-1)/eT(n_elem)) * r_var;
  163. }
  164. else
  165. {
  166. return eT(0);
  167. }
  168. }
  169. //! find the variance of an array (version for complex numbers)
  170. template<typename T>
  171. inline
  172. T
  173. op_var::direct_var(const std::complex<T>* const X, const uword n_elem, const uword norm_type)
  174. {
  175. arma_extra_debug_sigprint();
  176. typedef typename std::complex<T> eT;
  177. if(n_elem >= 2)
  178. {
  179. const eT acc1 = op_mean::direct_mean(X, n_elem);
  180. T acc2 = T(0);
  181. eT acc3 = eT(0);
  182. for(uword i=0; i<n_elem; ++i)
  183. {
  184. const eT tmp = acc1 - X[i];
  185. acc2 += std::norm(tmp);
  186. acc3 += tmp;
  187. }
  188. const T norm_val = (norm_type == 0) ? T(n_elem-1) : T(n_elem);
  189. const T var_val = (acc2 - std::norm(acc3)/T(n_elem)) / norm_val;
  190. return arma_isfinite(var_val) ? var_val : op_var::direct_var_robust(X, n_elem, norm_type);
  191. }
  192. else
  193. {
  194. return T(0);
  195. }
  196. }
  197. //! find the variance of an array (version for complex numbers) (robust but slow)
  198. template<typename T>
  199. inline
  200. T
  201. op_var::direct_var_robust(const std::complex<T>* const X, const uword n_elem, const uword norm_type)
  202. {
  203. arma_extra_debug_sigprint();
  204. typedef typename std::complex<T> eT;
  205. if(n_elem > 1)
  206. {
  207. eT r_mean = X[0];
  208. T r_var = T(0);
  209. for(uword i=1; i<n_elem; ++i)
  210. {
  211. const eT tmp = X[i] - r_mean;
  212. const T i_plus_1 = T(i+1);
  213. r_var = T(i-1)/T(i) * r_var + std::norm(tmp)/i_plus_1;
  214. r_mean = r_mean + tmp/i_plus_1;
  215. }
  216. return (norm_type == 0) ? r_var : (T(n_elem-1)/T(n_elem)) * r_var;
  217. }
  218. else
  219. {
  220. return T(0);
  221. }
  222. }
  223. //! @}