// fn_find.hpp

// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
// Copyright 2008-2016 National ICT Australia (NICTA)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ------------------------------------------------------------------------
//! \addtogroup fn_find
//! @{
  17. template<typename T1>
  18. arma_warn_unused
  19. inline
  20. typename
  21. enable_if2
  22. <
  23. is_arma_type<T1>::value,
  24. const mtOp<uword, T1, op_find_simple>
  25. >::result
  26. find(const T1& X)
  27. {
  28. arma_extra_debug_sigprint();
  29. return mtOp<uword, T1, op_find_simple>(X);
  30. }
  31. template<typename T1>
  32. arma_warn_unused
  33. inline
  34. const mtOp<uword, T1, op_find>
  35. find(const Base<typename T1::elem_type,T1>& X, const uword k, const char* direction = "first")
  36. {
  37. arma_extra_debug_sigprint();
  38. const char sig = (direction != NULL) ? direction[0] : char(0);
  39. arma_debug_check
  40. (
  41. ( (sig != 'f') && (sig != 'F') && (sig != 'l') && (sig != 'L') ),
  42. "find(): direction must be \"first\" or \"last\""
  43. );
  44. const uword type = ( (sig == 'f') || (sig == 'F') ) ? 0 : 1;
  45. return mtOp<uword, T1, op_find>(X.get_ref(), k, type);
  46. }
//
  48. template<typename T1>
  49. arma_warn_unused
  50. inline
  51. uvec
  52. find(const BaseCube<typename T1::elem_type,T1>& X)
  53. {
  54. arma_extra_debug_sigprint();
  55. typedef typename T1::elem_type eT;
  56. const unwrap_cube<T1> tmp(X.get_ref());
  57. const Mat<eT> R( const_cast< eT* >(tmp.M.memptr()), tmp.M.n_elem, 1, false );
  58. return find(R);
  59. }
  60. template<typename T1>
  61. arma_warn_unused
  62. inline
  63. uvec
  64. find(const BaseCube<typename T1::elem_type,T1>& X, const uword k, const char* direction = "first")
  65. {
  66. arma_extra_debug_sigprint();
  67. typedef typename T1::elem_type eT;
  68. const unwrap_cube<T1> tmp(X.get_ref());
  69. const Mat<eT> R( const_cast< eT* >(tmp.M.memptr()), tmp.M.n_elem, 1, false );
  70. return find(R, k, direction);
  71. }
  72. template<typename T1, typename op_rel_type>
  73. arma_warn_unused
  74. inline
  75. uvec
  76. find(const mtOpCube<uword, T1, op_rel_type>& X, const uword k = 0, const char* direction = "first")
  77. {
  78. arma_extra_debug_sigprint();
  79. typedef typename T1::elem_type eT;
  80. const unwrap_cube<T1> tmp(X.m);
  81. const Mat<eT> R( const_cast< eT* >(tmp.M.memptr()), tmp.M.n_elem, 1, false );
  82. return find( mtOp<uword, Mat<eT>, op_rel_type>(R, X.aux), k, direction );
  83. }
  84. template<typename T1, typename T2, typename glue_rel_type>
  85. arma_warn_unused
  86. inline
  87. uvec
  88. find(const mtGlueCube<uword, T1, T2, glue_rel_type>& X, const uword k = 0, const char* direction = "first")
  89. {
  90. arma_extra_debug_sigprint();
  91. typedef typename T1::elem_type eT1;
  92. typedef typename T2::elem_type eT2;
  93. const unwrap_cube<T1> tmp1(X.A);
  94. const unwrap_cube<T2> tmp2(X.B);
  95. arma_debug_assert_same_size( tmp1.M, tmp2.M, "relational operator" );
  96. const Mat<eT1> R1( const_cast< eT1* >(tmp1.M.memptr()), tmp1.M.n_elem, 1, false );
  97. const Mat<eT2> R2( const_cast< eT2* >(tmp2.M.memptr()), tmp2.M.n_elem, 1, false );
  98. return find( mtGlue<uword, Mat<eT1>, Mat<eT2>, glue_rel_type>(R1, R2), k, direction );
  99. }
//
  101. template<typename T1>
  102. arma_warn_unused
  103. inline
  104. Col<uword>
  105. find(const SpBase<typename T1::elem_type,T1>& X, const uword k = 0)
  106. {
  107. arma_extra_debug_sigprint();
  108. const SpProxy<T1> P(X.get_ref());
  109. const uword n_rows = P.get_n_rows();
  110. const uword n_nz = P.get_n_nonzero();
  111. Mat<uword> tmp(n_nz,1);
  112. uword* tmp_mem = tmp.memptr();
  113. typename SpProxy<T1>::const_iterator_type it = P.begin();
  114. for(uword i=0; i<n_nz; ++i)
  115. {
  116. const uword index = it.row() + it.col()*n_rows;
  117. tmp_mem[i] = index;
  118. ++it;
  119. }
  120. Col<uword> out;
  121. const uword count = (k == 0) ? uword(n_nz) : uword( (std::min)(n_nz, k) );
  122. out.steal_mem_col(tmp, count);
  123. return out;
  124. }
  125. template<typename T1>
  126. arma_warn_unused
  127. inline
  128. Col<uword>
  129. find(const SpBase<typename T1::elem_type,T1>& X, const uword k, const char* direction)
  130. {
  131. arma_extra_debug_sigprint();
  132. arma_check(true, "find(SpBase,k,direction): not implemented yet"); // TODO
  133. Col<uword> out;
  134. return out;
  135. }
//
  137. template<typename T1>
  138. arma_warn_unused
  139. inline
  140. typename
  141. enable_if2
  142. <
  143. is_arma_type<T1>::value,
  144. const mtOp<uword, T1, op_find_finite>
  145. >::result
  146. find_finite(const T1& X)
  147. {
  148. arma_extra_debug_sigprint();
  149. return mtOp<uword, T1, op_find_finite>(X);
  150. }
  151. template<typename T1>
  152. arma_warn_unused
  153. inline
  154. typename
  155. enable_if2
  156. <
  157. is_arma_type<T1>::value,
  158. const mtOp<uword, T1, op_find_nonfinite>
  159. >::result
  160. find_nonfinite(const T1& X)
  161. {
  162. arma_extra_debug_sigprint();
  163. return mtOp<uword, T1, op_find_nonfinite>(X);
  164. }
//
  166. template<typename T1>
  167. arma_warn_unused
  168. inline
  169. uvec
  170. find_finite(const BaseCube<typename T1::elem_type,T1>& X)
  171. {
  172. arma_extra_debug_sigprint();
  173. typedef typename T1::elem_type eT;
  174. const unwrap_cube<T1> tmp(X.get_ref());
  175. const Mat<eT> R( const_cast< eT* >(tmp.M.memptr()), tmp.M.n_elem, 1, false );
  176. return find_finite(R);
  177. }
  178. template<typename T1>
  179. arma_warn_unused
  180. inline
  181. uvec
  182. find_nonfinite(const BaseCube<typename T1::elem_type,T1>& X)
  183. {
  184. arma_extra_debug_sigprint();
  185. typedef typename T1::elem_type eT;
  186. const unwrap_cube<T1> tmp(X.get_ref());
  187. const Mat<eT> R( const_cast< eT* >(tmp.M.memptr()), tmp.M.n_elem, 1, false );
  188. return find_nonfinite(R);
  189. }
//
  191. template<typename T1>
  192. arma_warn_unused
  193. inline
  194. Col<uword>
  195. find_finite(const SpBase<typename T1::elem_type,T1>& X)
  196. {
  197. arma_extra_debug_sigprint();
  198. const SpProxy<T1> P(X.get_ref());
  199. const uword n_rows = P.get_n_rows();
  200. const uword n_nz = P.get_n_nonzero();
  201. Mat<uword> tmp(n_nz,1);
  202. uword* tmp_mem = tmp.memptr();
  203. typename SpProxy<T1>::const_iterator_type it = P.begin();
  204. uword count = 0;
  205. for(uword i=0; i<n_nz; ++i)
  206. {
  207. if(arma_isfinite(*it))
  208. {
  209. const uword index = it.row() + it.col()*n_rows;
  210. tmp_mem[count] = index;
  211. ++count;
  212. }
  213. ++it;
  214. }
  215. Col<uword> out;
  216. if(count > 0) { out.steal_mem_col(tmp, count); }
  217. return out;
  218. }
  219. template<typename T1>
  220. arma_warn_unused
  221. inline
  222. Col<uword>
  223. find_nonfinite(const SpBase<typename T1::elem_type,T1>& X)
  224. {
  225. arma_extra_debug_sigprint();
  226. const SpProxy<T1> P(X.get_ref());
  227. const uword n_rows = P.get_n_rows();
  228. const uword n_nz = P.get_n_nonzero();
  229. Mat<uword> tmp(n_nz,1);
  230. uword* tmp_mem = tmp.memptr();
  231. typename SpProxy<T1>::const_iterator_type it = P.begin();
  232. uword count = 0;
  233. for(uword i=0; i<n_nz; ++i)
  234. {
  235. if(arma_isfinite(*it) == false)
  236. {
  237. const uword index = it.row() + it.col()*n_rows;
  238. tmp_mem[count] = index;
  239. ++count;
  240. }
  241. ++it;
  242. }
  243. Col<uword> out;
  244. if(count > 0) { out.steal_mem_col(tmp, count); }
  245. return out;
  246. }
//! @}