fn_as_scalar.hpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup fn_as_scalar
  16. //! @{
  17. template<uword N>
  18. struct as_scalar_redirect
  19. {
  20. template<typename T1>
  21. inline static typename T1::elem_type apply(const T1& X);
  22. };
  23. template<>
  24. struct as_scalar_redirect<2>
  25. {
  26. template<typename T1, typename T2>
  27. inline static typename T1::elem_type apply(const Glue<T1,T2,glue_times>& X);
  28. };
  29. template<>
  30. struct as_scalar_redirect<3>
  31. {
  32. template<typename T1, typename T2, typename T3>
  33. inline static typename T1::elem_type apply(const Glue< Glue<T1, T2, glue_times>, T3, glue_times>& X);
  34. };
  35. template<uword N>
  36. template<typename T1>
  37. inline
  38. typename T1::elem_type
  39. as_scalar_redirect<N>::apply(const T1& X)
  40. {
  41. arma_extra_debug_sigprint();
  42. typedef typename T1::elem_type eT;
  43. const Proxy<T1> P(X);
  44. if(P.get_n_elem() != 1)
  45. {
  46. arma_debug_check(true, "as_scalar(): expression doesn't evaluate to exactly one element");
  47. return Datum<eT>::nan;
  48. }
  49. return (Proxy<T1>::use_at) ? P.at(0,0) : P[0];
  50. }
  51. template<typename T1, typename T2>
  52. inline
  53. typename T1::elem_type
  54. as_scalar_redirect<2>::apply(const Glue<T1, T2, glue_times>& X)
  55. {
  56. arma_extra_debug_sigprint();
  57. typedef typename T1::elem_type eT;
  58. // T1 must result in a matrix with one row
  59. // T2 must result in a matrix with one column
  60. const bool has_all_mat = (is_Mat<T1>::value || is_Mat_trans<T1>::value) && (is_Mat<T2>::value || is_Mat_trans<T2>::value);
  61. const bool use_at = (Proxy<T1>::use_at || Proxy<T2>::use_at);
  62. const bool do_partial_unwrap = (has_all_mat || use_at);
  63. if(do_partial_unwrap)
  64. {
  65. const partial_unwrap<T1> tmp1(X.A);
  66. const partial_unwrap<T2> tmp2(X.B);
  67. typedef typename partial_unwrap<T1>::stored_type TA;
  68. typedef typename partial_unwrap<T2>::stored_type TB;
  69. const TA& A = tmp1.M;
  70. const TB& B = tmp2.M;
  71. const uword A_n_rows = (tmp1.do_trans == false) ? (TA::is_row ? 1 : A.n_rows) : (TA::is_col ? 1 : A.n_cols);
  72. const uword A_n_cols = (tmp1.do_trans == false) ? (TA::is_col ? 1 : A.n_cols) : (TA::is_row ? 1 : A.n_rows);
  73. const uword B_n_rows = (tmp2.do_trans == false) ? (TB::is_row ? 1 : B.n_rows) : (TB::is_col ? 1 : B.n_cols);
  74. const uword B_n_cols = (tmp2.do_trans == false) ? (TB::is_col ? 1 : B.n_cols) : (TB::is_row ? 1 : B.n_rows);
  75. arma_debug_check( (A_n_rows != 1) || (B_n_cols != 1) || (A_n_cols != B_n_rows), "as_scalar(): incompatible dimensions" );
  76. const eT val = op_dot::direct_dot(A.n_elem, A.memptr(), B.memptr());
  77. return (tmp1.do_times || tmp2.do_times) ? (val * tmp1.get_val() * tmp2.get_val()) : val;
  78. }
  79. else
  80. {
  81. const Proxy<T1> PA(X.A);
  82. const Proxy<T2> PB(X.B);
  83. arma_debug_check
  84. (
  85. (PA.get_n_rows() != 1) || (PB.get_n_cols() != 1) || (PA.get_n_cols() != PB.get_n_rows()),
  86. "as_scalar(): incompatible dimensions"
  87. );
  88. return op_dot::apply_proxy(PA,PB);
  89. }
  90. }
  91. template<typename T1, typename T2, typename T3>
  92. inline
  93. typename T1::elem_type
  94. as_scalar_redirect<3>::apply(const Glue< Glue<T1, T2, glue_times>, T3, glue_times >& X)
  95. {
  96. arma_extra_debug_sigprint();
  97. typedef typename T1::elem_type eT;
  98. // T1 * T2 must result in a matrix with one row
  99. // T3 must result in a matrix with one column
  100. typedef typename strip_inv <T2 >::stored_type T2_stripped_1;
  101. typedef typename strip_diagmat<T2_stripped_1>::stored_type T2_stripped_2;
  102. const strip_inv <T2> strip1(X.A.B);
  103. const strip_diagmat<T2_stripped_1> strip2(strip1.M);
  104. const bool tmp2_do_inv = strip1.do_inv;
  105. const bool tmp2_do_diagmat = strip2.do_diagmat;
  106. if(tmp2_do_diagmat == false)
  107. {
  108. const Mat<eT> tmp(X);
  109. if(tmp.n_elem != 1)
  110. {
  111. arma_debug_check(true, "as_scalar(): expression doesn't evaluate to exactly one element");
  112. return Datum<eT>::nan;
  113. }
  114. return tmp[0];
  115. }
  116. else
  117. {
  118. const partial_unwrap<T1> tmp1(X.A.A);
  119. const partial_unwrap<T2_stripped_2> tmp2(strip2.M);
  120. const partial_unwrap<T3> tmp3(X.B);
  121. const Mat<eT>& A = tmp1.M;
  122. const Mat<eT>& B = tmp2.M;
  123. const Mat<eT>& C = tmp3.M;
  124. const uword A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols;
  125. const uword A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows;
  126. const bool B_is_vec = B.is_vec();
  127. const uword B_n_rows = (B_is_vec) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_rows : B.n_cols );
  128. const uword B_n_cols = (B_is_vec) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_cols : B.n_rows );
  129. const uword C_n_rows = (tmp3.do_trans == false) ? C.n_rows : C.n_cols;
  130. const uword C_n_cols = (tmp3.do_trans == false) ? C.n_cols : C.n_rows;
  131. const eT val = tmp1.get_val() * tmp2.get_val() * tmp3.get_val();
  132. arma_debug_check
  133. (
  134. (A_n_rows != 1) ||
  135. (C_n_cols != 1) ||
  136. (A_n_cols != B_n_rows) ||
  137. (B_n_cols != C_n_rows)
  138. ,
  139. "as_scalar(): incompatible dimensions"
  140. );
  141. if(B_is_vec)
  142. {
  143. if(tmp2_do_inv)
  144. {
  145. return val * op_dotext::direct_rowvec_invdiagvec_colvec(A.mem, B, C.mem);
  146. }
  147. else
  148. {
  149. return val * op_dot::direct_dot(A.n_elem, A.mem, B.mem, C.mem);
  150. }
  151. }
  152. else
  153. {
  154. if(tmp2_do_inv)
  155. {
  156. return val * op_dotext::direct_rowvec_invdiagmat_colvec(A.mem, B, C.mem);
  157. }
  158. else
  159. {
  160. return val * op_dotext::direct_rowvec_diagmat_colvec(A.mem, B, C.mem);
  161. }
  162. }
  163. }
  164. }
  165. template<typename T1>
  166. inline
  167. typename T1::elem_type
  168. as_scalar_diag(const Base<typename T1::elem_type,T1>& X)
  169. {
  170. arma_extra_debug_sigprint();
  171. typedef typename T1::elem_type eT;
  172. const unwrap<T1> tmp(X.get_ref());
  173. const Mat<eT>& A = tmp.M;
  174. if(A.n_elem != 1)
  175. {
  176. arma_debug_check(true, "as_scalar(): expression doesn't evaluate to exactly one element");
  177. return Datum<eT>::nan;
  178. }
  179. return A.mem[0];
  180. }
  181. template<typename T1, typename T2, typename T3>
  182. inline
  183. typename T1::elem_type
  184. as_scalar_diag(const Glue< Glue<T1, T2, glue_times_diag>, T3, glue_times >& X)
  185. {
  186. arma_extra_debug_sigprint();
  187. typedef typename T1::elem_type eT;
  188. // T1 * T2 must result in a matrix with one row
  189. // T3 must result in a matrix with one column
  190. typedef typename strip_diagmat<T2>::stored_type T2_stripped;
  191. const strip_diagmat<T2> strip(X.A.B);
  192. const partial_unwrap<T1> tmp1(X.A.A);
  193. const partial_unwrap<T2_stripped> tmp2(strip.M);
  194. const partial_unwrap<T3> tmp3(X.B);
  195. const Mat<eT>& A = tmp1.M;
  196. const Mat<eT>& B = tmp2.M;
  197. const Mat<eT>& C = tmp3.M;
  198. const uword A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols;
  199. const uword A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows;
  200. const bool B_is_vec = B.is_vec();
  201. const uword B_n_rows = (B_is_vec) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_rows : B.n_cols );
  202. const uword B_n_cols = (B_is_vec) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_cols : B.n_rows );
  203. const uword C_n_rows = (tmp3.do_trans == false) ? C.n_rows : C.n_cols;
  204. const uword C_n_cols = (tmp3.do_trans == false) ? C.n_cols : C.n_rows;
  205. const eT val = tmp1.get_val() * tmp2.get_val() * tmp3.get_val();
  206. arma_debug_check
  207. (
  208. (A_n_rows != 1) ||
  209. (C_n_cols != 1) ||
  210. (A_n_cols != B_n_rows) ||
  211. (B_n_cols != C_n_rows)
  212. ,
  213. "as_scalar(): incompatible dimensions"
  214. );
  215. if(B_is_vec)
  216. {
  217. return val * op_dot::direct_dot(A.n_elem, A.mem, B.mem, C.mem);
  218. }
  219. else
  220. {
  221. return val * op_dotext::direct_rowvec_diagmat_colvec(A.mem, B, C.mem);
  222. }
  223. }
  224. template<typename T1, typename T2>
  225. arma_warn_unused
  226. arma_inline
  227. typename T1::elem_type
  228. as_scalar(const Glue<T1, T2, glue_times>& X, const typename arma_not_cx<typename T1::elem_type>::result* junk = 0)
  229. {
  230. arma_extra_debug_sigprint();
  231. arma_ignore(junk);
  232. if(is_glue_times_diag<T1>::value == false)
  233. {
  234. const sword N_mat = 1 + depth_lhs< glue_times, Glue<T1,T2,glue_times> >::num;
  235. arma_extra_debug_print(arma_str::format("N_mat = %d") % N_mat);
  236. return as_scalar_redirect<N_mat>::apply(X);
  237. }
  238. else
  239. {
  240. return as_scalar_diag(X);
  241. }
  242. }
  243. template<typename T1>
  244. arma_warn_unused
  245. inline
  246. typename T1::elem_type
  247. as_scalar(const Base<typename T1::elem_type,T1>& X)
  248. {
  249. arma_extra_debug_sigprint();
  250. typedef typename T1::elem_type eT;
  251. const Proxy<T1> P(X.get_ref());
  252. if(P.get_n_elem() != 1)
  253. {
  254. arma_debug_check(true, "as_scalar(): expression doesn't evaluate to exactly one element");
  255. return Datum<eT>::nan;
  256. }
  257. return (Proxy<T1>::use_at) ? P.at(0,0) : P[0];
  258. }
  259. template<typename T1>
  260. arma_warn_unused
  261. inline
  262. typename T1::elem_type
  263. as_scalar(const Gen<T1, gen_randu>& X)
  264. {
  265. arma_extra_debug_sigprint();
  266. typedef typename T1::elem_type eT;
  267. if( (X.n_rows != 1) || (X.n_cols != 1) )
  268. {
  269. arma_debug_check(true, "as_scalar(): expression doesn't evaluate to exactly one element");
  270. return Datum<eT>::nan;
  271. }
  272. return eT(arma_rng::randu<eT>());
  273. }
  274. template<typename T1>
  275. arma_warn_unused
  276. inline
  277. typename T1::elem_type
  278. as_scalar(const Gen<T1, gen_randn>& X)
  279. {
  280. arma_extra_debug_sigprint();
  281. typedef typename T1::elem_type eT;
  282. if( (X.n_rows != 1) || (X.n_cols != 1) )
  283. {
  284. arma_debug_check(true, "as_scalar(): expression doesn't evaluate to exactly one element");
  285. return Datum<eT>::nan;
  286. }
  287. return eT(arma_rng::randn<eT>());
  288. }
  289. template<typename T1>
  290. arma_warn_unused
  291. inline
  292. typename T1::elem_type
  293. as_scalar(const BaseCube<typename T1::elem_type,T1>& X)
  294. {
  295. arma_extra_debug_sigprint();
  296. typedef typename T1::elem_type eT;
  297. const ProxyCube<T1> P(X.get_ref());
  298. if(P.get_n_elem() != 1)
  299. {
  300. arma_debug_check(true, "as_scalar(): expression doesn't evaluate to exactly one element");
  301. return Datum<eT>::nan;
  302. }
  303. return (ProxyCube<T1>::use_at) ? P.at(0,0,0) : P[0];
  304. }
  305. template<typename T>
  306. arma_warn_unused
  307. arma_inline
  308. typename arma_scalar_only<T>::result
  309. as_scalar(const T& x)
  310. {
  311. return x;
  312. }
  313. template<typename T1>
  314. arma_warn_unused
  315. inline
  316. typename T1::elem_type
  317. as_scalar(const SpBase<typename T1::elem_type, T1>& X)
  318. {
  319. typedef typename T1::elem_type eT;
  320. const unwrap_spmat<T1> tmp(X.get_ref());
  321. const SpMat<eT>& A = tmp.M;
  322. if(A.n_elem != 1)
  323. {
  324. arma_debug_check(true, "as_scalar(): expression doesn't evaluate to exactly one element");
  325. return Datum<eT>::nan;
  326. }
  327. return A.at(0,0);
  328. }
  329. //! @}