diagmat_proxy.hpp 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup diagmat_proxy
  16. //! @{
  17. template<typename T1>
  18. class diagmat_proxy_default
  19. {
  20. public:
  21. typedef typename T1::elem_type elem_type;
  22. typedef typename get_pod_type<elem_type>::result pod_type;
  23. inline
  24. diagmat_proxy_default(const T1& X)
  25. : P ( X )
  26. , P_is_vec( (resolves_to_vector<T1>::yes) || (P.get_n_rows() == 1) || (P.get_n_cols() == 1) )
  27. , P_is_col( T1::is_col || (P.get_n_cols() == 1) )
  28. , n_rows ( P_is_vec ? P.get_n_elem() : P.get_n_rows() )
  29. , n_cols ( P_is_vec ? P.get_n_elem() : P.get_n_cols() )
  30. {
  31. arma_extra_debug_sigprint();
  32. }
  33. arma_inline
  34. elem_type
  35. operator[](const uword i) const
  36. {
  37. if(Proxy<T1>::use_at == false)
  38. {
  39. return P_is_vec ? P[i] : P.at(i,i);
  40. }
  41. else
  42. {
  43. if(P_is_vec)
  44. {
  45. return (P_is_col) ? P.at(i,0) : P.at(0,i);
  46. }
  47. else
  48. {
  49. return P.at(i,i);
  50. }
  51. }
  52. }
  53. arma_inline
  54. elem_type
  55. at(const uword row, const uword col) const
  56. {
  57. if(row == col)
  58. {
  59. if(Proxy<T1>::use_at == false)
  60. {
  61. return (P_is_vec) ? P[row] : P.at(row,row);
  62. }
  63. else
  64. {
  65. if(P_is_vec)
  66. {
  67. return (P_is_col) ? P.at(row,0) : P.at(0,row);
  68. }
  69. else
  70. {
  71. return P.at(row,row);
  72. }
  73. }
  74. }
  75. else
  76. {
  77. return elem_type(0);
  78. }
  79. }
  80. inline bool is_alias(const Mat<elem_type>& X) const { return P.is_alias(X); }
  81. const Proxy<T1> P;
  82. const bool P_is_vec;
  83. const bool P_is_col;
  84. const uword n_rows;
  85. const uword n_cols;
  86. };
  87. template<typename T1>
  88. class diagmat_proxy_fixed
  89. {
  90. public:
  91. typedef typename T1::elem_type elem_type;
  92. typedef typename get_pod_type<elem_type>::result pod_type;
  93. inline
  94. diagmat_proxy_fixed(const T1& X)
  95. : P(X)
  96. {
  97. arma_extra_debug_sigprint();
  98. }
  99. arma_inline
  100. elem_type
  101. operator[](const uword i) const
  102. {
  103. return (P_is_vec) ? P[i] : P.at(i,i);
  104. }
  105. arma_inline
  106. elem_type
  107. at(const uword row, const uword col) const
  108. {
  109. if(row == col)
  110. {
  111. return (P_is_vec) ? P[row] : P.at(row,row);
  112. }
  113. else
  114. {
  115. return elem_type(0);
  116. }
  117. }
  118. arma_inline bool is_alias(const Mat<elem_type>& X) const { return (void_ptr(&X) == void_ptr(&P)); }
  119. const T1& P;
  120. static const bool P_is_vec = (T1::n_rows == 1) || (T1::n_cols == 1);
  121. static const uword n_rows = P_is_vec ? T1::n_elem : T1::n_rows;
  122. static const uword n_cols = P_is_vec ? T1::n_elem : T1::n_cols;
  123. };
  124. template<typename T1, bool condition>
  125. struct diagmat_proxy_redirect {};
  126. template<typename T1>
  127. struct diagmat_proxy_redirect<T1, false> { typedef diagmat_proxy_default<T1> result; };
  128. template<typename T1>
  129. struct diagmat_proxy_redirect<T1, true> { typedef diagmat_proxy_fixed<T1> result; };
  130. template<typename T1>
  131. class diagmat_proxy : public diagmat_proxy_redirect<T1, is_Mat_fixed<T1>::value >::result
  132. {
  133. public:
  134. inline diagmat_proxy(const T1& X)
  135. : diagmat_proxy_redirect< T1, is_Mat_fixed<T1>::value >::result(X)
  136. {
  137. }
  138. };
  139. template<typename eT>
  140. class diagmat_proxy< Mat<eT> >
  141. {
  142. public:
  143. typedef eT elem_type;
  144. typedef typename get_pod_type<elem_type>::result pod_type;
  145. inline
  146. diagmat_proxy(const Mat<eT>& X)
  147. : P ( X )
  148. , P_is_vec( (X.n_rows == 1) || (X.n_cols == 1) )
  149. , n_rows ( P_is_vec ? X.n_elem : X.n_rows )
  150. , n_cols ( P_is_vec ? X.n_elem : X.n_cols )
  151. {
  152. arma_extra_debug_sigprint();
  153. }
  154. arma_inline elem_type operator[] (const uword i) const { return P_is_vec ? P[i] : P.at(i,i); }
  155. arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? ( P_is_vec ? P[row] : P.at(row,row) ) : elem_type(0); }
  156. arma_inline bool is_alias(const Mat<eT>& X) const { return (void_ptr(&X) == void_ptr(&P)); }
  157. const Mat<eT>& P;
  158. const bool P_is_vec;
  159. const uword n_rows;
  160. const uword n_cols;
  161. };
  162. template<typename eT>
  163. class diagmat_proxy< Row<eT> >
  164. {
  165. public:
  166. typedef eT elem_type;
  167. typedef typename get_pod_type<elem_type>::result pod_type;
  168. inline
  169. diagmat_proxy(const Row<eT>& X)
  170. : P(X)
  171. , n_rows(X.n_elem)
  172. , n_cols(X.n_elem)
  173. {
  174. arma_extra_debug_sigprint();
  175. }
  176. arma_inline elem_type operator[] (const uword i) const { return P[i]; }
  177. arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? P[row] : elem_type(0); }
  178. arma_inline bool is_alias(const Mat<eT>& X) const { return (void_ptr(&X) == void_ptr(&P)); }
  179. static const bool P_is_vec = true;
  180. const Row<eT>& P;
  181. const uword n_rows;
  182. const uword n_cols;
  183. };
  184. template<typename eT>
  185. class diagmat_proxy< Col<eT> >
  186. {
  187. public:
  188. typedef eT elem_type;
  189. typedef typename get_pod_type<elem_type>::result pod_type;
  190. inline
  191. diagmat_proxy(const Col<eT>& X)
  192. : P(X)
  193. , n_rows(X.n_elem)
  194. , n_cols(X.n_elem)
  195. {
  196. arma_extra_debug_sigprint();
  197. }
  198. arma_inline elem_type operator[] (const uword i) const { return P[i]; }
  199. arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? P[row] : elem_type(0); }
  200. arma_inline bool is_alias(const Mat<eT>& X) const { return (void_ptr(&X) == void_ptr(&P)); }
  201. static const bool P_is_vec = true;
  202. const Col<eT>& P;
  203. const uword n_rows;
  204. const uword n_cols;
  205. };
  206. template<typename eT>
  207. class diagmat_proxy< subview_row<eT> >
  208. {
  209. public:
  210. typedef eT elem_type;
  211. typedef typename get_pod_type<elem_type>::result pod_type;
  212. inline
  213. diagmat_proxy(const subview_row<eT>& X)
  214. : P(X)
  215. , n_rows(X.n_elem)
  216. , n_cols(X.n_elem)
  217. {
  218. arma_extra_debug_sigprint();
  219. }
  220. arma_inline elem_type operator[] (const uword i) const { return P[i]; }
  221. arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? P[row] : elem_type(0); }
  222. arma_inline bool is_alias(const Mat<eT>& X) const { return (void_ptr(&X) == void_ptr(&(P.m))); }
  223. static const bool P_is_vec = true;
  224. const subview_row<eT>& P;
  225. const uword n_rows;
  226. const uword n_cols;
  227. };
  228. template<typename eT>
  229. class diagmat_proxy< subview_col<eT> >
  230. {
  231. public:
  232. typedef eT elem_type;
  233. typedef typename get_pod_type<elem_type>::result pod_type;
  234. inline
  235. diagmat_proxy(const subview_col<eT>& X)
  236. : P(X)
  237. , n_rows(X.n_elem)
  238. , n_cols(X.n_elem)
  239. {
  240. arma_extra_debug_sigprint();
  241. }
  242. arma_inline elem_type operator[] (const uword i) const { return P[i]; }
  243. arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? P[row] : elem_type(0); }
  244. arma_inline bool is_alias(const Mat<eT>& X) const { return (void_ptr(&X) == void_ptr(&(P.m))); }
  245. static const bool P_is_vec = true;
  246. const subview_col<eT>& P;
  247. const uword n_rows;
  248. const uword n_cols;
  249. };
  250. template<typename T1, typename T2>
  251. class diagmat_proxy< Glue<T1,T2,glue_times> >
  252. {
  253. public:
  254. typedef typename T1::elem_type elem_type;
  255. typedef typename get_pod_type<elem_type>::result pod_type;
  256. inline
  257. diagmat_proxy(const Glue<T1,T2,glue_times>& X)
  258. {
  259. op_diagmat::apply_times(P, X.A, X.B);
  260. n_rows = P.n_rows;
  261. n_cols = P.n_cols;
  262. arma_extra_debug_sigprint();
  263. }
  264. arma_inline elem_type operator[] (const uword i) const { return P.at(i,i); }
  265. arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? P.at(row,row) : elem_type(0); }
  266. arma_inline bool is_alias(const Mat<elem_type>&) const { return false; }
  267. static const bool P_is_vec = false;
  268. Mat<elem_type> P;
  269. uword n_rows;
  270. uword n_cols;
  271. };
  272. //
  273. //
  274. //
  275. template<typename T1>
  276. class diagmat_proxy_check_default
  277. {
  278. public:
  279. typedef typename T1::elem_type elem_type;
  280. typedef typename get_pod_type<elem_type>::result pod_type;
  281. inline
  282. diagmat_proxy_check_default(const T1& X, const Mat<typename T1::elem_type>&)
  283. : P(X)
  284. , P_is_vec( (resolves_to_vector<T1>::yes) || (P.n_rows == 1) || (P.n_cols == 1) )
  285. , n_rows( P_is_vec ? P.n_elem : P.n_rows )
  286. , n_cols( P_is_vec ? P.n_elem : P.n_cols )
  287. {
  288. arma_extra_debug_sigprint();
  289. }
  290. arma_inline elem_type operator[] (const uword i) const { return P_is_vec ? P[i] : P.at(i,i); }
  291. arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? ( P_is_vec ? P[row] : P.at(row,row) ) : elem_type(0); }
  292. const Mat<elem_type> P;
  293. const bool P_is_vec;
  294. const uword n_rows;
  295. const uword n_cols;
  296. };
  297. template<typename T1>
  298. class diagmat_proxy_check_fixed
  299. {
  300. public:
  301. typedef typename T1::elem_type eT;
  302. typedef typename T1::elem_type elem_type;
  303. typedef typename get_pod_type<elem_type>::result pod_type;
  304. inline
  305. diagmat_proxy_check_fixed(const T1& X, const Mat<eT>& out)
  306. : P( const_cast<eT*>(X.memptr()), T1::n_rows, T1::n_cols, (&X == &out), false )
  307. {
  308. arma_extra_debug_sigprint();
  309. }
  310. arma_inline eT operator[] (const uword i) const { return P_is_vec ? P[i] : P.at(i,i); }
  311. arma_inline eT at (const uword row, const uword col) const { return (row == col) ? ( P_is_vec ? P[row] : P.at(row,row) ) : elem_type(0); }
  312. const Mat<eT> P; // TODO: why not just store X directly as T1& ? test with fixed size vectors and matrices
  313. static const bool P_is_vec = (T1::n_rows == 1) || (T1::n_cols == 1);
  314. static const uword n_rows = P_is_vec ? T1::n_elem : T1::n_rows;
  315. static const uword n_cols = P_is_vec ? T1::n_elem : T1::n_cols;
  316. };
  317. template<typename T1, bool condition>
  318. struct diagmat_proxy_check_redirect {};
  319. template<typename T1>
  320. struct diagmat_proxy_check_redirect<T1, false> { typedef diagmat_proxy_check_default<T1> result; };
  321. template<typename T1>
  322. struct diagmat_proxy_check_redirect<T1, true> { typedef diagmat_proxy_check_fixed<T1> result; };
  323. template<typename T1>
  324. class diagmat_proxy_check : public diagmat_proxy_check_redirect<T1, is_Mat_fixed<T1>::value >::result
  325. {
  326. public:
  327. inline diagmat_proxy_check(const T1& X, const Mat<typename T1::elem_type>& out)
  328. : diagmat_proxy_check_redirect< T1, is_Mat_fixed<T1>::value >::result(X, out)
  329. {
  330. }
  331. };
  332. template<typename eT>
  333. class diagmat_proxy_check< Mat<eT> >
  334. {
  335. public:
  336. typedef eT elem_type;
  337. typedef typename get_pod_type<elem_type>::result pod_type;
  338. inline
  339. diagmat_proxy_check(const Mat<eT>& X, const Mat<eT>& out)
  340. : P_local ( (&X == &out) ? new Mat<eT>(X) : 0 )
  341. , P ( (&X == &out) ? (*P_local) : X )
  342. , P_is_vec( (P.n_rows == 1) || (P.n_cols == 1) )
  343. , n_rows ( P_is_vec ? P.n_elem : P.n_rows )
  344. , n_cols ( P_is_vec ? P.n_elem : P.n_cols )
  345. {
  346. arma_extra_debug_sigprint();
  347. }
  348. inline ~diagmat_proxy_check()
  349. {
  350. if(P_local) { delete P_local; }
  351. }
  352. arma_inline elem_type operator[] (const uword i) const { return P_is_vec ? P[i] : P.at(i,i); }
  353. arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? ( P_is_vec ? P[row] : P.at(row,row) ) : elem_type(0); }
  354. const Mat<eT>* P_local;
  355. const Mat<eT>& P;
  356. const bool P_is_vec;
  357. const uword n_rows;
  358. const uword n_cols;
  359. };
  360. template<typename eT>
  361. class diagmat_proxy_check< Row<eT> >
  362. {
  363. public:
  364. typedef eT elem_type;
  365. typedef typename get_pod_type<elem_type>::result pod_type;
  366. inline
  367. diagmat_proxy_check(const Row<eT>& X, const Mat<eT>& out)
  368. : P_local ( (&X == reinterpret_cast<const Row<eT>*>(&out)) ? new Row<eT>(X) : 0 )
  369. , P ( (&X == reinterpret_cast<const Row<eT>*>(&out)) ? (*P_local) : X )
  370. , n_rows (X.n_elem)
  371. , n_cols (X.n_elem)
  372. {
  373. arma_extra_debug_sigprint();
  374. }
  375. inline ~diagmat_proxy_check()
  376. {
  377. if(P_local) { delete P_local; }
  378. }
  379. arma_inline elem_type operator[] (const uword i) const { return P[i]; }
  380. arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? P[row] : elem_type(0); }
  381. static const bool P_is_vec = true;
  382. const Row<eT>* P_local;
  383. const Row<eT>& P;
  384. const uword n_rows;
  385. const uword n_cols;
  386. };
  387. template<typename eT>
  388. class diagmat_proxy_check< Col<eT> >
  389. {
  390. public:
  391. typedef eT elem_type;
  392. typedef typename get_pod_type<elem_type>::result pod_type;
  393. inline
  394. diagmat_proxy_check(const Col<eT>& X, const Mat<eT>& out)
  395. : P_local ( (&X == reinterpret_cast<const Col<eT>*>(&out)) ? new Col<eT>(X) : 0 )
  396. , P ( (&X == reinterpret_cast<const Col<eT>*>(&out)) ? (*P_local) : X )
  397. , n_rows (X.n_elem)
  398. , n_cols (X.n_elem)
  399. {
  400. arma_extra_debug_sigprint();
  401. }
  402. inline ~diagmat_proxy_check()
  403. {
  404. if(P_local) { delete P_local; }
  405. }
  406. arma_inline elem_type operator[] (const uword i) const { return P[i]; }
  407. arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? P[row] : elem_type(0); }
  408. static const bool P_is_vec = true;
  409. const Col<eT>* P_local;
  410. const Col<eT>& P;
  411. const uword n_rows;
  412. const uword n_cols;
  413. };
  414. template<typename eT>
  415. class diagmat_proxy_check< subview_row<eT> >
  416. {
  417. public:
  418. typedef eT elem_type;
  419. typedef typename get_pod_type<elem_type>::result pod_type;
  420. inline
  421. diagmat_proxy_check(const subview_row<eT>& X, const Mat<eT>&)
  422. : P ( X )
  423. , n_rows ( X.n_elem )
  424. , n_cols ( X.n_elem )
  425. {
  426. arma_extra_debug_sigprint();
  427. }
  428. arma_inline elem_type operator[] (const uword i) const { return P[i]; }
  429. arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? P[row] : elem_type(0); }
  430. static const bool P_is_vec = true;
  431. const Row<eT> P;
  432. const uword n_rows;
  433. const uword n_cols;
  434. };
  435. template<typename eT>
  436. class diagmat_proxy_check< subview_col<eT> >
  437. {
  438. public:
  439. typedef eT elem_type;
  440. typedef typename get_pod_type<elem_type>::result pod_type;
  441. inline
  442. diagmat_proxy_check(const subview_col<eT>& X, const Mat<eT>& out)
  443. : P ( const_cast<eT*>(X.colptr(0)), X.n_rows, (&(X.m) == &out), false )
  444. , n_rows( X.n_elem )
  445. , n_cols( X.n_elem )
  446. {
  447. arma_extra_debug_sigprint();
  448. }
  449. arma_inline elem_type operator[] (const uword i) const { return P[i]; }
  450. arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? P[row] : elem_type(0); }
  451. static const bool P_is_vec = true;
  452. const Col<eT> P;
  453. const uword n_rows;
  454. const uword n_cols;
  455. };
  456. template<typename T1, typename T2>
  457. class diagmat_proxy_check< Glue<T1,T2,glue_times> >
  458. {
  459. public:
  460. typedef typename T1::elem_type elem_type;
  461. typedef typename get_pod_type<elem_type>::result pod_type;
  462. inline
  463. diagmat_proxy_check(const Glue<T1,T2,glue_times>& X, const Mat<elem_type>&)
  464. {
  465. op_diagmat::apply_times(P, X.A, X.B);
  466. n_rows = P.n_rows;
  467. n_cols = P.n_cols;
  468. arma_extra_debug_sigprint();
  469. }
  470. arma_inline elem_type operator[] (const uword i) const { return P.at(i,i); }
  471. arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? P.at(row,row) : elem_type(0); }
  472. static const bool P_is_vec = false;
  473. Mat<elem_type> P;
  474. uword n_rows;
  475. uword n_cols;
  476. };
  477. //! @}