// arma_forward.hpp (11 KB)
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. using std::cout;
  16. using std::cerr;
  17. using std::endl;
  18. using std::ios;
  19. using std::size_t;
  20. template<typename elem_type, typename derived> struct Base;
  21. template<typename elem_type, typename derived> struct BaseCube;
  22. template<typename eT> class Mat;
  23. template<typename eT> class Col;
  24. template<typename eT> class Row;
  25. template<typename eT> class Cube;
  26. template<typename eT> class xvec_htrans;
  27. template<typename oT> class field;
  28. template<typename eT, bool do_conj> class xtrans_mat;
  29. template<typename eT> class subview;
  30. template<typename eT> class subview_col;
  31. template<typename eT> class subview_row;
  32. template<typename eT> class subview_row_strans;
  33. template<typename eT> class subview_row_htrans;
  34. template<typename eT> class subview_cube;
  35. template<typename oT> class subview_field;
  36. template<typename eT> class SpValProxy;
  37. template<typename eT> class SpMat;
  38. template<typename eT> class SpCol;
  39. template<typename eT> class SpRow;
  40. template<typename eT> class SpSubview;
  41. template<typename eT> class SpSubview_col;
  42. template<typename eT> class SpSubview_row;
  43. template<typename eT> class diagview;
  44. template<typename eT> class spdiagview;
  45. template<typename eT> class MapMat;
  46. template<typename eT> class MapMat_val;
  47. template<typename eT> class SpMat_MapMat_val;
  48. template<typename eT> class SpSubview_MapMat_val;
  49. template<typename eT, typename T1> class subview_elem1;
  50. template<typename eT, typename T1, typename T2> class subview_elem2;
  51. template<typename parent, unsigned int mode> class subview_each1;
  52. template<typename parent, unsigned int mode, typename TB> class subview_each2;
  53. template<typename eT> class subview_cube_each1;
  54. template<typename eT, typename TB> class subview_cube_each2;
  55. template<typename eT, typename T1> class subview_cube_slices;
  56. class SizeMat;
  57. class SizeCube;
  //! Empty placeholder class: no members, no base classes.
  class arma_empty_class
    {
    };
  59. class diskio;
  60. class op_strans;
  61. class op_htrans;
  62. class op_htrans2;
  63. class op_inv;
  64. class op_inv_sympd;
  65. class op_diagmat;
  66. class op_trimat;
  67. class op_vectorise_row;
  68. class op_vectorise_col;
  69. class glue_times;
  70. class glue_times_diag;
  71. class glue_rel_lt;
  72. class glue_rel_gt;
  73. class glue_rel_lteq;
  74. class glue_rel_gteq;
  75. class glue_rel_eq;
  76. class glue_rel_noteq;
  77. class glue_rel_and;
  78. class glue_rel_or;
  79. class op_rel_lt_pre;
  80. class op_rel_lt_post;
  81. class op_rel_gt_pre;
  82. class op_rel_gt_post;
  83. class op_rel_lteq_pre;
  84. class op_rel_lteq_post;
  85. class op_rel_gteq_pre;
  86. class op_rel_gteq_post;
  87. class op_rel_eq;
  88. class op_rel_noteq;
  89. class gen_eye;
  90. class gen_ones;
  91. class gen_zeros;
  92. class gen_randu;
  93. class gen_randn;
  94. class spop_strans;
  95. class spop_htrans;
  96. class spop_vectorise_row;
  97. class spop_vectorise_col;
  98. class spglue_plus;
  99. class spglue_minus;
  100. class spglue_schur;
  101. class spglue_times;
  102. class spglue_max;
  103. class spglue_min;
  104. class spglue_rel_lt;
  105. class spglue_rel_gt;
  106. class op_internal_equ;
  107. class op_internal_plus;
  108. class op_internal_minus;
  109. class op_internal_schur;
  110. class op_internal_div;
  //! Trait policy for single-operand operations: every shape flag is false,
  //! regardless of the operand type T1.
  struct traits_op_default
    {
    template<typename T1>
    struct traits
      {
      static const bool is_row  = false;
      static const bool is_col  = false;
      static const bool is_xvec = false;
      };
    };
  //! Trait policy for single-operand operations: only is_xvec is true;
  //! is_row and is_col are false, regardless of the operand type T1.
  struct traits_op_xvec
    {
    template<typename T1>
    struct traits
      {
      static const bool is_row  = false;
      static const bool is_col  = false;
      static const bool is_xvec = true;
      };
    };
  //! Trait policy for single-operand operations: only is_col is true;
  //! is_row and is_xvec are false, regardless of the operand type T1.
  struct traits_op_col
    {
    template<typename T1>
    struct traits
      {
      static const bool is_row  = false;
      static const bool is_col  = true;
      static const bool is_xvec = false;
      };
    };
  //! Trait policy for single-operand operations: only is_row is true;
  //! is_col and is_xvec are false, regardless of the operand type T1.
  struct traits_op_row
    {
    template<typename T1>
    struct traits
      {
      static const bool is_row  = true;
      static const bool is_col  = false;
      static const bool is_xvec = false;
      };
    };
  //! Trait policy for single-operand operations: each shape flag is copied
  //! unchanged from the operand type T1, which must therefore provide
  //! is_row, is_col and is_xvec as compile-time constants.
  struct traits_op_passthru
    {
    template<typename T1>
    struct traits
      {
      static const bool is_row  = T1::is_row;
      static const bool is_col  = T1::is_col;
      static const bool is_xvec = T1::is_xvec;
      };
    };
  //! Trait policy for two-operand (glue) operations: every shape flag is
  //! false, regardless of the operand types T1 and T2.
  struct traits_glue_default
    {
    template<typename T1, typename T2>
    struct traits
      {
      static const bool is_row  = false;
      static const bool is_col  = false;
      static const bool is_xvec = false;
      };
    };
  //! Trait policy for two-operand (glue) operations: a shape flag is true
  //! when the matching flag is true in at least one of T1 and T2.
  //! Both operand types must provide is_row, is_col and is_xvec.
  struct traits_glue_or
    {
    template<typename T1, typename T2>
    struct traits
      {
      static const bool is_row  = (T1::is_row  || T2::is_row );
      static const bool is_col  = (T1::is_col  || T2::is_col );
      static const bool is_xvec = (T1::is_xvec || T2::is_xvec);
      };
    };
  181. template<const bool, const bool, const bool, const bool> class gemm;
  182. template<const bool, const bool, const bool> class gemv;
  183. template< typename eT, typename gen_type> class Gen;
  184. template< typename T1, typename op_type> class Op;
  185. template< typename T1, typename eop_type> class eOp;
  186. template< typename T1, typename op_type> class SpToDOp;
  187. template< typename T1, typename op_type> class CubeToMatOp;
  188. template<typename out_eT, typename T1, typename op_type> class mtOp;
  189. template< typename T1, typename T2, typename glue_type> class Glue;
  190. template< typename T1, typename T2, typename eglue_type> class eGlue;
  191. template<typename out_eT, typename T1, typename T2, typename glue_type> class mtGlue;
  192. template< typename eT, typename gen_type> class GenCube;
  193. template< typename T1, typename op_type> class OpCube;
  194. template< typename T1, typename eop_type> class eOpCube;
  195. template<typename out_eT, typename T1, typename op_type> class mtOpCube;
  196. template< typename T1, typename T2, typename glue_type> class GlueCube;
  197. template< typename T1, typename T2, typename eglue_type> class eGlueCube;
  198. template<typename out_eT, typename T1, typename T2, typename glue_type> class mtGlueCube;
  199. template<typename T1> class Proxy;
  200. template<typename T1> class ProxyCube;
  201. template<typename T1> class diagmat_proxy;
  202. template<typename T1> struct unwrap;
  203. template<typename T1> struct unwrap_cube;
  204. template<typename T1> struct unwrap_spmat;
  205. struct state_type
  206. {
  207. #if defined(ARMA_USE_OPENMP)
  208. int state;
  209. #elif (defined(ARMA_USE_CXX11) && !defined(ARMA_DONT_USE_CXX11_MUTEX))
  210. std::atomic<int> state;
  211. #else
  212. int state;
  213. #endif
  214. arma_inline state_type() : state(int(0)) {}
  215. // openmp: "omp atomic" does an implicit flush on the affected variable
  216. // C++11: std::atomic<>::load() and std::atomic<>::store() use std::memory_order_seq_cst by default, which has an implied fence
  217. arma_inline
  218. operator int () const
  219. {
  220. int out;
  221. #if defined(ARMA_USE_OPENMP)
  222. #pragma omp atomic read
  223. out = state;
  224. #elif (defined(ARMA_USE_CXX11) && !defined(ARMA_DONT_USE_CXX11_MUTEX))
  225. out = state.load();
  226. #else
  227. out = state;
  228. #endif
  229. return out;
  230. }
  231. arma_inline
  232. void
  233. operator= (const int in_state)
  234. {
  235. #if defined(ARMA_USE_OPENMP)
  236. #pragma omp atomic write
  237. state = in_state;
  238. #elif (defined(ARMA_USE_CXX11) && !defined(ARMA_DONT_USE_CXX11_MUTEX))
  239. state.store(in_state);
  240. #else
  241. state = in_state;
  242. #endif
  243. }
  244. };
  245. template< typename T1, typename spop_type> class SpOp;
  246. template<typename out_eT, typename T1, typename spop_type> class mtSpOp;
  247. template< typename T1, typename T2, typename spglue_type> class SpGlue;
  248. template<typename out_eT, typename T1, typename T2, typename spglue_type> class mtSpGlue;
  249. template<typename T1> class SpProxy;
  // Empty tag types with no members.
  // NOTE(review): presumably used to select constructors/overloads elsewhere in the library -- usage is not visible in this file.
  struct arma_vec_indicator     { };
  struct arma_fixed_indicator   { };
  struct arma_reserve_indicator { };
  struct arma_layout_indicator  { };
  254. //! \addtogroup injector
  255. //! @{
  //! Empty marker type for the "end of row" token; the template parameter is unused by the type itself.
  template<typename Dummy = int>
  struct injector_end_of_row
    {
    };

  //! endr indicates "end of row" when using the << operator;
  //! similar conceptual meaning to std::endl
  static const injector_end_of_row<> endr = injector_end_of_row<>();
  260. //! @}
  261. //! \addtogroup diskio
  262. //! @{
  //! Identifiers for the file formats listed below.
  //! Enumerator values are implicit, so the declaration order determines the numeric value of each entry.
  enum file_type
    {
    file_type_unknown,

    //! attempt to automatically detect the file type
    auto_detect,

    //! raw text (ASCII), without a header
    raw_ascii,

    //! Armadillo text format, with a header specifying matrix type and size
    arma_ascii,

    //! comma separated values (CSV), without a header
    csv_ascii,

    //! raw binary format (machine dependent), without a header
    raw_binary,

    //! Armadillo binary format (machine dependent), with a header specifying matrix type and size
    arma_binary,

    //! Portable Grey Map (greyscale image)
    pgm_binary,

    //! Portable Pixel Map (colour image), used by the field and cube classes
    ppm_binary,

    //! HDF5: open binary format, not specific to Armadillo, which can store arbitrary data
    hdf5_binary,

    //! [DO NOT USE - deprecated] as per hdf5_binary, but save/load the data with columns transposed to rows
    hdf5_binary_trans,

    //! simple co-ordinate format for sparse matrices (indices start at zero)
    coord_ascii
    };
  278. struct hdf5_name;
  279. struct csv_name;
  280. //! @}
  281. //! \addtogroup fill
  282. //! @{
  namespace fill
    {
    // Empty tag types, one per fill kind.
    struct fill_none  { };
    struct fill_zeros { };
    struct fill_ones  { };
    struct fill_eye   { };
    struct fill_randu { };
    struct fill_randn { };

    //! Wrapper that carries one of the fill tags as its template argument.
    //! The explicitly written (empty) default constructor allows the const
    //! objects below to be declared without an initialiser.
    template<typename fill_type>
    struct fill_class
      {
      inline fill_class() { }
      };

    // One constant object per fill kind.
    static const fill_class<fill_none > none;
    static const fill_class<fill_zeros> zeros;
    static const fill_class<fill_ones > ones;
    static const fill_class<fill_eye  > eye;
    static const fill_class<fill_randu> randu;
    static const fill_class<fill_randn> randn;
    }
  300. //! @}
  301. //! \addtogroup fn_spsolve
  302. //! @{
  //! Common base for the spsolve option structs.
  //! Stores an immutable numeric id supplied by the derived type at construction.
  struct spsolve_opts_base
    {
    const unsigned int id;

    //! \param in_id  value stored in the id member
    inline
    spsolve_opts_base(const unsigned int in_id)
      : id(in_id)
      {
      }
    };
  308. struct spsolve_opts_none : public spsolve_opts_base
  309. {
  310. inline spsolve_opts_none() : spsolve_opts_base(0) {}
  311. };
  312. struct superlu_opts : public spsolve_opts_base
  313. {
  314. typedef enum {NATURAL, MMD_ATA, MMD_AT_PLUS_A, COLAMD} permutation_type;
  315. typedef enum {REF_NONE, REF_SINGLE, REF_DOUBLE, REF_EXTRA} refine_type;
  316. bool allow_ugly;
  317. bool equilibrate;
  318. bool symmetric;
  319. double pivot_thresh;
  320. permutation_type permutation;
  321. refine_type refine;
  322. inline superlu_opts()
  323. : spsolve_opts_base(1)
  324. {
  325. allow_ugly = false;
  326. equilibrate = false;
  327. symmetric = false;
  328. pivot_thresh = 1.0;
  329. permutation = COLAMD;
  330. refine = REF_NONE;
  331. }
  332. };
  333. //! @}