12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989 |
- // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
- // Copyright 2008-2016 National ICT Australia (NICTA)
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- // ------------------------------------------------------------------------
- //! \addtogroup glue_times
- //! @{
- //! Evaluates the two-operand product X.A * X.B into 'out' (this variant performs no detection of inverses).
- //! Each operand is passed through partial_unwrap; the resulting do_trans / do_times flags and the
- //! product of the get_val() scalars are forwarded to the two-matrix glue_times::apply().
- //! If either unwrapped operand reports (via is_alias) that it aliases 'out', the product is formed
- //! in a scratch matrix whose memory is then handed to 'out' with steal_mem().
- template<bool do_inv_detect>
- template<typename T1, typename T2>
- arma_hot
- inline
- void
- glue_times_redirect2_helper<do_inv_detect>::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
-   {
-   arma_extra_debug_sigprint();
- 
-   typedef typename T1::elem_type eT;
- 
-   typedef partial_unwrap<T1> lhs_unwrap;
-   typedef partial_unwrap<T2> rhs_unwrap;
- 
-   const lhs_unwrap lhs(X.A);
-   const rhs_unwrap rhs(X.B);
- 
-   const typename lhs_unwrap::stored_type& A = lhs.M;
-   const typename rhs_unwrap::stored_type& B = rhs.M;
- 
-   // the scalars are only queried when at least one operand is flagged as carrying one
-   const bool have_scalar = lhs_unwrap::do_times || rhs_unwrap::do_times;
-   const eT   scalar      = have_scalar ? (lhs.get_val() * rhs.get_val()) : eT(0);
- 
-   if( lhs.is_alias(out) || rhs.is_alias(out) )
-     {
-     // an operand is reported as aliasing 'out': build the result elsewhere first
-     Mat<eT> scratch;
- 
-     glue_times::apply<eT, lhs_unwrap::do_trans, rhs_unwrap::do_trans, have_scalar>(scratch, A, B, scalar);
- 
-     out.steal_mem(scratch);
-     }
-   else
-     {
-     glue_times::apply<eT, lhs_unwrap::do_trans, rhs_unwrap::do_trans, have_scalar>(out, A, B, scalar);
-     }
-   }
- //! Two-operand product with detection of an inverse on either side.
- //!  * inv(A)*B (flagged by strip_inv<T1>::do_inv) is evaluated through auxlib's solve routines.
- //!  * with ARMA_OPTIMISE_SYMPD defined, A*inv_sympd(B) is evaluated as trans( solve(B, trans(A)) ).
- //!  * every other expression is passed on to glue_times_redirect2_helper<false>::apply().
- //! When a solve reports failure, 'out' is soft-reset and arma_stop_runtime_error() is called.
- template<typename T1, typename T2>
- arma_hot
- inline
- void
- glue_times_redirect2_helper<true>::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
-   {
-   arma_extra_debug_sigprint();
- 
-   typedef typename T1::elem_type eT;
- 
-   if(strip_inv<T1>::do_inv)
-     {
-     // inv(A)*B is rewritten as solve(A,B)
- 
-     arma_extra_debug_print("glue_times_redirect<2>::apply(): detected inv(A)*B");
- 
-     const strip_inv<T1> lhs_stripped(X.A);
- 
-     Mat<eT> A = lhs_stripped.M;  // local copy of the stripped operand
- 
-     arma_debug_check( (A.is_square() == false), "inv(): given matrix must be square sized" );
- 
-     // NOTE: a stricter variant (soft-reset 'out' and raise an error when the symmetry check fails) is disabled;
-     // NOTE: a failed check only produces a warning, and the check only runs when arma_config::debug is set
-     if( (strip_inv<T1>::do_inv_sympd) && (arma_config::debug) && (auxlib::rudimentary_sym_check(A) == false) )
-       {
-       if(is_cx<eT>::no ) { arma_debug_warn("inv_sympd(): given matrix is not symmetric"); }
-       if(is_cx<eT>::yes) { arma_debug_warn("inv_sympd(): given matrix is not hermitian"); }
-       }
- 
-     const unwrap_check<T2> rhs(X.B, out);
-     const Mat<eT>&         B = rhs.M;
- 
-     arma_debug_assert_mul_size(A, B, "matrix multiplication");
- 
-     // TODO: detect sympd via sympd_helper::guess_sympd(A) ?
- 
-     #if defined(ARMA_OPTIMISE_SYMPD)
-       const bool solved = (strip_inv<T1>::do_inv_sympd) ? auxlib::solve_sympd_fast(out, A, B) : auxlib::solve_square_fast(out, A, B);
-     #else
-       const bool solved = auxlib::solve_square_fast(out, A, B);
-     #endif
- 
-     if(solved == false)
-       {
-       out.soft_reset();
-       arma_stop_runtime_error("matrix multiplication: problem with matrix inverse; suggest to use solve() instead");
-       }
- 
-     return;
-     }
- 
-   #if defined(ARMA_OPTIMISE_SYMPD)
-     if(strip_inv<T2>::do_inv_sympd)
-       {
-       // A*inv_sympd(B) is rewritten as trans( solve(trans(B),trans(A)) );
-       // B is handed to the solver as-is (no transpose), since it is explicitly marked as symmetric
- 
-       arma_extra_debug_print("glue_times_redirect<2>::apply(): detected A*inv_sympd(B)");
- 
-       const Mat<eT> A_trans = trans(X.A);
- 
-       const strip_inv<T2> rhs_stripped(X.B);
- 
-       Mat<eT> B = rhs_stripped.M;  // local copy of the stripped operand
- 
-       arma_debug_check( (B.is_square() == false), "inv_sympd(): given matrix must be square sized" );
- 
-       // NOTE: as in the inv(A)*B case, the abort-on-asymmetry variant is disabled; only a debug-mode warning remains
-       if( (arma_config::debug) && (auxlib::rudimentary_sym_check(B) == false) )
-         {
-         if(is_cx<eT>::no ) { arma_debug_warn("inv_sympd(): given matrix is not symmetric"); }
-         if(is_cx<eT>::yes) { arma_debug_warn("inv_sympd(): given matrix is not hermitian"); }
-         }
- 
-       // A_trans holds trans(A), hence its dimensions are passed swapped to describe A itself
-       arma_debug_assert_mul_size(A_trans.n_cols, A_trans.n_rows, B.n_rows, B.n_cols, "matrix multiplication");
- 
-       const bool solved = auxlib::solve_sympd_fast(out, B, A_trans);
- 
-       if(solved == false)
-         {
-         out.soft_reset();
-         arma_stop_runtime_error("matrix multiplication: problem with matrix inverse; suggest to use solve() instead");
-         }
- 
-       out = trans(out);
- 
-       return;
-       }
-   #endif
- 
-   // no inverse detected: plain multiplication
-   glue_times_redirect2_helper<false>::apply(out, X);
-   }
- //! Evaluates the three-operand product X.A.A * X.A.B * X.B into 'out' (this variant performs no detection of inverses).
- //! The operands are passed through partial_unwrap; their do_trans / do_times flags and the product
- //! of the get_val() scalars are forwarded to the three-matrix glue_times::apply().
- //! If any unwrapped operand reports aliasing with 'out', the result is built in a scratch matrix
- //! and handed to 'out' with steal_mem().
- template<bool do_inv_detect>
- template<typename T1, typename T2, typename T3>
- arma_hot
- inline
- void
- glue_times_redirect3_helper<do_inv_detect>::apply(Mat<typename T1::elem_type>& out, const Glue< Glue<T1,T2,glue_times>, T3, glue_times>& X)
-   {
-   arma_extra_debug_sigprint();
- 
-   typedef typename T1::elem_type eT;
- 
-   typedef partial_unwrap<T1> unwrap_1;
-   typedef partial_unwrap<T2> unwrap_2;
-   typedef partial_unwrap<T3> unwrap_3;
- 
-   // the expression holds exactly three objects: X.A.A, X.A.B and X.B
-   const unwrap_1 first (X.A.A);
-   const unwrap_2 second(X.A.B);
-   const unwrap_3 third (X.B  );
- 
-   const typename unwrap_1::stored_type& A = first.M;
-   const typename unwrap_2::stored_type& B = second.M;
-   const typename unwrap_3::stored_type& C = third.M;
- 
-   const bool have_scalar = unwrap_1::do_times || unwrap_2::do_times || unwrap_3::do_times;
-   const eT   scalar      = have_scalar ? (first.get_val() * second.get_val() * third.get_val()) : eT(0);
- 
-   if( first.is_alias(out) || second.is_alias(out) || third.is_alias(out) )
-     {
-     // an operand is reported as aliasing 'out': build the result elsewhere first
-     Mat<eT> scratch;
- 
-     glue_times::apply<eT, unwrap_1::do_trans, unwrap_2::do_trans, unwrap_3::do_trans, have_scalar>(scratch, A, B, C, scalar);
- 
-     out.steal_mem(scratch);
-     }
-   else
-     {
-     glue_times::apply<eT, unwrap_1::do_trans, unwrap_2::do_trans, unwrap_3::do_trans, have_scalar>(out, A, B, C, scalar);
-     }
-   }
- //! Three-operand product with detection of an inverse in the first or second position.
- //!  * inv(A)*B*C is evaluated as solve(A, B*C)
- //!  * A*inv(B)*C is evaluated as A*solve(B,C)
- //!  * every other expression is passed on to glue_times_redirect3_helper<false>::apply()
- //! The solves go through auxlib; when one reports failure, 'out' is soft-reset and
- //! arma_stop_runtime_error() is called.
- template<typename T1, typename T2, typename T3>
- arma_hot
- inline
- void
- glue_times_redirect3_helper<true>::apply(Mat<typename T1::elem_type>& out, const Glue< Glue<T1,T2,glue_times>, T3, glue_times>& X)
-   {
-   arma_extra_debug_sigprint();
- 
-   typedef typename T1::elem_type eT;
- 
-   if(strip_inv<T1>::do_inv)
-     {
-     // inv(A)*B*C is rewritten as solve(A,B*C)
- 
-     arma_extra_debug_print("glue_times_redirect<3>::apply(): detected inv(A)*B*C");
- 
-     const strip_inv<T1> first_stripped(X.A.A);
- 
-     Mat<eT> A = first_stripped.M;  // local copy of the stripped operand
- 
-     arma_debug_check( (A.is_square() == false), "inv(): given matrix must be square sized" );
- 
-     typedef partial_unwrap<T2> unwrap_2;
-     typedef partial_unwrap<T3> unwrap_3;
- 
-     const unwrap_2 second(X.A.B);
-     const unwrap_3 third (X.B  );
- 
-     const typename unwrap_2::stored_type& B = second.M;
-     const typename unwrap_3::stored_type& C = third.M;
- 
-     const bool have_scalar = unwrap_2::do_times || unwrap_3::do_times;
-     const eT   scalar      = have_scalar ? (second.get_val() * third.get_val()) : eT(0);
- 
-     // right-hand side of the solve: the product B*C
-     Mat<eT> BC;
- 
-     glue_times::apply<eT, unwrap_2::do_trans, unwrap_3::do_trans, have_scalar>(BC, B, C, scalar);
- 
-     arma_debug_assert_mul_size(A, BC, "matrix multiplication");
- 
-     // TODO: detect sympd via sympd_helper::guess_sympd(A) ?
- 
-     #if defined(ARMA_OPTIMISE_SYMPD)
-       const bool solved = (strip_inv<T1>::do_inv_sympd) ? auxlib::solve_sympd_fast(out, A, BC) : auxlib::solve_square_fast(out, A, BC);
-     #else
-       const bool solved = auxlib::solve_square_fast(out, A, BC);
-     #endif
- 
-     if(solved == false)
-       {
-       out.soft_reset();
-       arma_stop_runtime_error("matrix multiplication: problem with matrix inverse; suggest to use solve() instead");
-       }
- 
-     return;
-     }
- 
- 
-   if(strip_inv<T2>::do_inv)
-     {
-     // A*inv(B)*C is rewritten as A*solve(B,C)
- 
-     arma_extra_debug_print("glue_times_redirect<3>::apply(): detected A*inv(B)*C");
- 
-     const strip_inv<T2> second_stripped(X.A.B);
- 
-     Mat<eT> B = second_stripped.M;  // local copy of the stripped operand
- 
-     arma_debug_check( (B.is_square() == false), "inv(): given matrix must be square sized" );
- 
-     const unwrap<T3> third(X.B);
-     const Mat<eT>&   C = third.M;
- 
-     arma_debug_assert_mul_size(B, C, "matrix multiplication");
- 
-     // the solve writes into a separate matrix, not into 'out'
-     Mat<eT> solved_BC;
- 
-     #if defined(ARMA_OPTIMISE_SYMPD)
-       const bool solved = (strip_inv<T2>::do_inv_sympd) ? auxlib::solve_sympd_fast(solved_BC, B, C) : auxlib::solve_square_fast(solved_BC, B, C);
-     #else
-       const bool solved = auxlib::solve_square_fast(solved_BC, B, C);
-     #endif
- 
-     if(solved == false)
-       {
-       out.soft_reset();
-       arma_stop_runtime_error("matrix multiplication: problem with matrix inverse; suggest to use solve() instead");
-       return;
-       }
- 
-     typedef partial_unwrap_check<T1> unwrap_1;
- 
-     const unwrap_1 first(X.A.A, out);
- 
-     const typename unwrap_1::stored_type& A = first.M;
- 
-     const bool have_scalar = unwrap_1::do_times;
-     const eT   scalar      = have_scalar ? first.get_val() : eT(0);
- 
-     // final step: out = A * solve(B,C); the solve result is passed with its transpose flag set to false
-     glue_times::apply<eT, unwrap_1::do_trans, false, have_scalar>(out, A, solved_BC, scalar);
- 
-     return;
-     }
- 
- 
-   // no inverse detected: plain multiplication
-   glue_times_redirect3_helper<false>::apply(out, X);
-   }
- //! Generic glue_times_redirect<N>::apply(): multiplies X.A by X.B into 'out'.
- //! Both sides are passed through partial_unwrap, and the resulting transpose / scaling flags and
- //! combined scalar are forwarded to the two-matrix glue_times::apply().
- //! An operand reported as aliasing 'out' causes the product to be formed in a temporary first.
- template<uword N>
- template<typename T1, typename T2>
- arma_hot
- inline
- void
- glue_times_redirect<N>::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
-   {
-   arma_extra_debug_sigprint();
- 
-   typedef typename T1::elem_type eT;
- 
-   typedef partial_unwrap<T1> UA;
-   typedef partial_unwrap<T2> UB;
- 
-   const UA unwrapped_A(X.A);
-   const UB unwrapped_B(X.B);
- 
-   const typename UA::stored_type& A = unwrapped_A.M;
-   const typename UB::stored_type& B = unwrapped_B.M;
- 
-   const bool scale  = UA::do_times || UB::do_times;
-   const eT   factor = scale ? (unwrapped_A.get_val() * unwrapped_B.get_val()) : eT(0);
- 
-   const bool overlaps_out = unwrapped_A.is_alias(out) || unwrapped_B.is_alias(out);
- 
-   if(overlaps_out)
-     {
-     Mat<eT> product;
- 
-     glue_times::apply<eT, UA::do_trans, UB::do_trans, scale>(product, A, B, factor);
- 
-     // hand the temporary's memory to 'out'
-     out.steal_mem(product);
-     }
-   else
-     {
-     glue_times::apply<eT, UA::do_trans, UB::do_trans, scale>(out, A, B, factor);
-     }
-   }
- //! Two-operand specialisation: forwards to glue_times_redirect2_helper, whose boolean template
- //! argument is taken from is_supported_blas_type<eT>::value.
- template<typename T1, typename T2>
- arma_hot
- inline
- void
- glue_times_redirect<2>::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
-   {
-   arma_extra_debug_sigprint();
- 
-   typedef typename T1::elem_type eT;
- 
-   typedef glue_times_redirect2_helper< is_supported_blas_type<eT>::value > helper_type;
- 
-   helper_type::apply(out, X);
-   }
- //! Three-operand specialisation: forwards to glue_times_redirect3_helper, whose boolean template
- //! argument is taken from is_supported_blas_type<eT>::value.
- template<typename T1, typename T2, typename T3>
- arma_hot
- inline
- void
- glue_times_redirect<3>::apply(Mat<typename T1::elem_type>& out, const Glue< Glue<T1,T2,glue_times>, T3, glue_times>& X)
-   {
-   arma_extra_debug_sigprint();
- 
-   typedef typename T1::elem_type eT;
- 
-   typedef glue_times_redirect3_helper< is_supported_blas_type<eT>::value > helper_type;
- 
-   helper_type::apply(out, X);
-   }
- //! Four-operand specialisation: evaluates X.A.A.A * X.A.A.B * X.A.B * X.B into 'out'.
- //! The operands are passed through partial_unwrap; their do_trans / do_times flags and the product
- //! of the get_val() scalars are forwarded to the four-matrix glue_times::apply().
- //! If any unwrapped operand reports aliasing with 'out', the result is built in a scratch matrix
- //! and handed to 'out' with steal_mem().
- template<typename T1, typename T2, typename T3, typename T4>
- arma_hot
- inline
- void
- glue_times_redirect<4>::apply(Mat<typename T1::elem_type>& out, const Glue< Glue< Glue<T1,T2,glue_times>, T3, glue_times>, T4, glue_times>& X)
-   {
-   arma_extra_debug_sigprint();
- 
-   typedef typename T1::elem_type eT;
- 
-   typedef partial_unwrap<T1> unwrap_1;
-   typedef partial_unwrap<T2> unwrap_2;
-   typedef partial_unwrap<T3> unwrap_3;
-   typedef partial_unwrap<T4> unwrap_4;
- 
-   // the expression holds exactly four objects: X.A.A.A, X.A.A.B, X.A.B and X.B
-   const unwrap_1 first (X.A.A.A);
-   const unwrap_2 second(X.A.A.B);
-   const unwrap_3 third (X.A.B  );
-   const unwrap_4 fourth(X.B    );
- 
-   const typename unwrap_1::stored_type& A = first.M;
-   const typename unwrap_2::stored_type& B = second.M;
-   const typename unwrap_3::stored_type& C = third.M;
-   const typename unwrap_4::stored_type& D = fourth.M;
- 
-   const bool have_scalar = unwrap_1::do_times || unwrap_2::do_times || unwrap_3::do_times || unwrap_4::do_times;
-   const eT   scalar      = have_scalar ? (first.get_val() * second.get_val() * third.get_val() * fourth.get_val()) : eT(0);
- 
-   if( first.is_alias(out) || second.is_alias(out) || third.is_alias(out) || fourth.is_alias(out) )
-     {
-     // an operand is reported as aliasing 'out': build the result elsewhere first
-     Mat<eT> scratch;
- 
-     glue_times::apply<eT, unwrap_1::do_trans, unwrap_2::do_trans, unwrap_3::do_trans, unwrap_4::do_trans, have_scalar>(scratch, A, B, C, D, scalar);
- 
-     out.steal_mem(scratch);
-     }
-   else
-     {
-     glue_times::apply<eT, unwrap_1::do_trans, unwrap_2::do_trans, unwrap_3::do_trans, unwrap_4::do_trans, have_scalar>(out, A, B, C, D, scalar);
-     }
-   }
- //! Entry point for a glue_times expression: derives the operand count as
- //! 1 + depth_lhs<glue_times, Glue<T1,T2,glue_times> >::num and dispatches to the
- //! glue_times_redirect specialisation selected by that count.
- template<typename T1, typename T2>
- arma_hot
- inline
- void
- glue_times::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
-   {
-   arma_extra_debug_sigprint();
- 
-   typedef Glue<T1,T2,glue_times> expr_type;
- 
-   const sword n_operands = 1 + depth_lhs< glue_times, expr_type >::num;
- 
-   arma_extra_debug_print(arma_str::format("N_mat = %d") % n_operands);
- 
-   glue_times_redirect<n_operands>::apply(out, X);
-   }
- //! In-place multiplication: assigns the product out * X back to 'out'.
- template<typename T1>
- arma_hot
- inline
- void
- glue_times::apply_inplace(Mat<typename T1::elem_type>& out, const T1& X)
-   {
-   arma_extra_debug_sigprint();
- 
-   out = (out * X);
-   }
- //! Accumulates a two-operand product onto an existing matrix:
- //! out += X.A * X.B when sign > 0, and out -= X.A * X.B for any other sign.
- //!
- //! Expressions flagged by is_outer_product / has_op_inv / has_op_inv_sympd are evaluated into a
- //! temporary which is then added to or subtracted from 'out'. Otherwise both operands are unwrapped
- //! with partial_unwrap_check (which is given 'out'), the sizes are checked against 'out', and one of
- //! the gemv / syrk / herk / gemm kernels is called with a trailing eT(1) (T(1) for herk) argument
- //! (presumably the BLAS-style beta that retains the current contents of 'out'; confirm in the kernels).
- //!
- //! A non-positive sign is folded into alpha as a factor of -1, which in turn forces the
- //! alpha-aware (use_alpha == true) kernels.
- template<typename T1, typename T2>
- arma_hot
- inline
- void
- glue_times::apply_inplace_plus(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_times>& X, const sword sign)
-   {
-   arma_extra_debug_sigprint();
- 
-   typedef typename T1::elem_type            eT;
-   typedef typename get_pod_type<eT>::result T;
- 
-   if( (is_outer_product<T1>::value) || (has_op_inv<T1>::value) || (has_op_inv<T2>::value) || (has_op_inv_sympd<T1>::value) || (has_op_inv_sympd<T2>::value) )
-     {
-     // partial workaround for corner cases:
-     // evaluate the whole expression into a temporary, then accumulate it
- 
-     const Mat<eT> tmp(X);
- 
-     if(sign > sword(0)) { out += tmp; } else { out -= tmp; }
- 
-     return;
-     }
- 
-   const partial_unwrap_check<T1> tmp1(X.A, out);
-   const partial_unwrap_check<T2> tmp2(X.B, out);
- 
-   typedef typename partial_unwrap_check<T1>::stored_type TA;
-   typedef typename partial_unwrap_check<T2>::stored_type TB;
- 
-   const TA& A = tmp1.M;
-   const TB& B = tmp2.M;
- 
-   const bool do_trans_A = partial_unwrap_check<T1>::do_trans;
-   const bool do_trans_B = partial_unwrap_check<T2>::do_trans;
- 
-   // Every non-positive sign means subtraction: that is how the temporary-based path above, the
-   // alpha factor below and the size-check label treat it. The alpha-aware kernels must therefore be
-   // selected for sign <= 0; testing only (sign < 0) would let sign == 0 with unscaled operands reach
-   // the alpha-free kernels and add the product instead.
-   const bool use_alpha = partial_unwrap_check<T1>::do_times || partial_unwrap_check<T2>::do_times || (sign <= sword(0));
- 
-   const eT alpha = use_alpha ? ( tmp1.get_val() * tmp2.get_val() * ( (sign > sword(0)) ? eT(1) : eT(-1) ) ) : eT(0);
- 
-   arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiplication");
- 
-   // dimensions of the product; is_row / is_col statically fix a dimension to 1 for vector types
-   const uword result_n_rows = (do_trans_A == false) ? (TA::is_row ? 1 : A.n_rows) : (TA::is_col ? 1 : A.n_cols);
-   const uword result_n_cols = (do_trans_B == false) ? (TB::is_col ? 1 : B.n_cols) : (TB::is_row ? 1 : B.n_rows);
- 
-   arma_debug_assert_same_size(out.n_rows, out.n_cols, result_n_rows, result_n_cols, ( (sign > sword(0)) ? "addition" : "subtraction" ) );
- 
-   // nothing to accumulate into an empty matrix
-   if(out.n_elem == 0)
-     {
-     return;
-     }
- 
- 
-   // Dispatch on (do_trans_A, do_trans_B, use_alpha): vector-shaped operands go to gemv,
-   // products of a matrix with itself (A and B are the same object) go to syrk / herk,
-   // and everything else goes to gemm.
- 
-   if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) )
-     {
-     if( ((A.n_rows == 1) || (TA::is_row)) && (is_cx<eT>::no) ) { gemv<true, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); }
-     else if( (B.n_cols == 1) || (TB::is_col) ) { gemv<false, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); }
-     else { gemm<false, false, false, true>::apply(out, A, B, alpha, eT(1)); }
-     }
-   else
-   if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) )
-     {
-     if( ((A.n_rows == 1) || (TA::is_row)) && (is_cx<eT>::no) ) { gemv<true, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); }
-     else if( (B.n_cols == 1) || (TB::is_col) ) { gemv<false, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); }
-     else { gemm<false, false, true, true>::apply(out, A, B, alpha, eT(1)); }
-     }
-   else
-   if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) )
-     {
-     if( ((A.n_cols == 1) || (TA::is_col)) && (is_cx<eT>::no) ) { gemv<true, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); }
-     else if( (B.n_cols == 1) || (TB::is_col) ) { gemv<true, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); }
-     else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx<eT>::no) ) { syrk<true, false, true>::apply(out, A, alpha, eT(1)); }
-     else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx<eT>::yes) ) { herk<true, false, true>::apply(out, A, T(0), T(1)); }
-     else { gemm<true, false, false, true>::apply(out, A, B, alpha, eT(1)); }
-     }
-   else
-   if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) )
-     {
-     if( ((A.n_cols == 1) || (TA::is_col)) && (is_cx<eT>::no) ) { gemv<true, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); }
-     else if( (B.n_cols == 1) || (TB::is_col) ) { gemv<true, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); }
-     else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx<eT>::no) ) { syrk<true, true, true>::apply(out, A, alpha, eT(1)); }
-     else { gemm<true, false, true, true>::apply(out, A, B, alpha, eT(1)); }
-     }
-   else
-   if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) )
-     {
-     if( ((A.n_rows == 1) || (TA::is_row)) && (is_cx<eT>::no) ) { gemv<false, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); }
-     else if( ((B.n_rows == 1) || (TB::is_row)) && (is_cx<eT>::no) ) { gemv<false, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); }
-     else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx<eT>::no) ) { syrk<false, false, true>::apply(out, A, alpha, eT(1)); }
-     else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx<eT>::yes) ) { herk<false, false, true>::apply(out, A, T(0), T(1)); }
-     else { gemm<false, true, false, true>::apply(out, A, B, alpha, eT(1)); }
-     }
-   else
-   if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) )
-     {
-     if( ((A.n_rows == 1) || (TA::is_row)) && (is_cx<eT>::no) ) { gemv<false, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); }
-     else if( ((B.n_rows == 1) || (TB::is_row)) && (is_cx<eT>::no) ) { gemv<false, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); }
-     else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx<eT>::no) ) { syrk<false, true, true>::apply(out, A, alpha, eT(1)); }
-     else { gemm<false, true, true, true>::apply(out, A, B, alpha, eT(1)); }
-     }
-   else
-   if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) )
-     {
-     if( ((A.n_cols == 1) || (TA::is_col)) && (is_cx<eT>::no) ) { gemv<false, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); }
-     else if( ((B.n_rows == 1) || (TB::is_row)) && (is_cx<eT>::no) ) { gemv<true, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); }
-     else { gemm<true, true, false, true>::apply(out, A, B, alpha, eT(1)); }
-     }
-   else
-   if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) )
-     {
-     if( ((A.n_cols == 1) || (TA::is_col)) && (is_cx<eT>::no) ) { gemv<false, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); }
-     else if( ((B.n_rows == 1) || (TB::is_row)) && (is_cx<eT>::no) ) { gemv<true, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); }
-     else { gemm<true, true, true, true>::apply(out, A, B, alpha, eT(1)); }
-     }
-   }
- //! Returns the number of elements in the product of A and B, taking the do_trans_A / do_trans_B
- //! flags into account: (rows of the possibly transposed A) * (columns of the possibly transposed B).
- //! For types whose is_row / is_col flag is set, the corresponding dimension is taken as 1.
- template<typename eT, const bool do_trans_A, const bool do_trans_B, typename TA, typename TB>
- arma_inline
- uword
- glue_times::mul_storage_cost(const TA& A, const TB& B)
-   {
-   const uword product_n_rows = (do_trans_A) ? ( TA::is_col ? 1 : A.n_cols ) : ( TA::is_row ? 1 : A.n_rows );
-   const uword product_n_cols = (do_trans_B) ? ( TB::is_row ? 1 : B.n_rows ) : ( TB::is_col ? 1 : B.n_cols );
- 
-   return product_n_rows * product_n_cols;
-   }
- template
- <
- typename eT,
- const bool do_trans_A,
- const bool do_trans_B,
- const bool use_alpha,
- typename TA,
- typename TB
- >
- arma_hot
- inline
- void
- glue_times::apply
- (
- Mat<eT>& out,
- const TA& A,
- const TB& B,
- const eT alpha
- )
- {
- arma_extra_debug_sigprint();
- // Two-matrix product: sizes 'out' for A*B (with A and/or B taken as transposed when do_trans_A / do_trans_B is set) and fills it via a gemv / syrk / herk / gemm kernel; alpha is only forwarded to the kernel when use_alpha is true.
- //arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiplication");
- arma_debug_assert_trans_mul_size<do_trans_A, do_trans_B>(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication");
- // result dimensions follow the transpose flags; is_row / is_col statically fix a dimension to 1 for vector types
- const uword final_n_rows = (do_trans_A == false) ? (TA::is_row ? 1 : A.n_rows) : (TA::is_col ? 1 : A.n_cols);
- const uword final_n_cols = (do_trans_B == false) ? (TB::is_col ? 1 : B.n_cols) : (TB::is_row ? 1 : B.n_rows);
- 
- out.set_size(final_n_rows, final_n_cols);
- // an empty operand yields an all-zero result of the size computed above; no kernel is called
- if( (A.n_elem == 0) || (B.n_elem == 0) )
- {
- out.zeros();
- return;
- }
- 
- // Dispatch on (do_trans_A, do_trans_B, use_alpha): vector-shaped operands go to gemv, products of a matrix with itself (A and B are the same object) go to syrk / herk, everything else goes to gemm.
- if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) )
- {
- if( ((A.n_rows == 1) || (TA::is_row)) && (is_cx<eT>::no) ) { gemv<true, false, false>::apply(out.memptr(), B, A.memptr()); }
- else if( (B.n_cols == 1) || (TB::is_col) ) { gemv<false, false, false>::apply(out.memptr(), A, B.memptr()); }
- else { gemm<false, false, false, false>::apply(out, A, B ); }
- }
- else
- if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) )
- {
- if( ((A.n_rows == 1) || (TA::is_row)) && (is_cx<eT>::no) ) { gemv<true, true, false>::apply(out.memptr(), B, A.memptr(), alpha); }
- else if( (B.n_cols == 1) || (TB::is_col) ) { gemv<false, true, false>::apply(out.memptr(), A, B.memptr(), alpha); }
- else { gemm<false, false, true, false>::apply(out, A, B, alpha); }
- }
- else
- if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) )
- {
- if( ((A.n_cols == 1) || (TA::is_col)) && (is_cx<eT>::no) ) { gemv<true, false, false>::apply(out.memptr(), B, A.memptr()); }
- else if( (B.n_cols == 1) || (TB::is_col) ) { gemv<true, false, false>::apply(out.memptr(), A, B.memptr()); }
- else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx<eT>::no) ) { syrk<true, false, false>::apply(out, A ); }
- else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx<eT>::yes) ) { herk<true, false, false>::apply(out, A ); }
- else { gemm<true, false, false, false>::apply(out, A, B ); }
- }
- else
- if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) )
- {
- if( ((A.n_cols == 1) || (TA::is_col)) && (is_cx<eT>::no) ) { gemv<true, true, false>::apply(out.memptr(), B, A.memptr(), alpha); }
- else if( (B.n_cols == 1) || (TB::is_col) ) { gemv<true, true, false>::apply(out.memptr(), A, B.memptr(), alpha); }
- else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx<eT>::no) ) { syrk<true, true, false>::apply(out, A, alpha); }
- else { gemm<true, false, true, false>::apply(out, A, B, alpha); }
- }
- else
- if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) )
- {
- if( ((A.n_rows == 1) || (TA::is_row)) && (is_cx<eT>::no) ) { gemv<false, false, false>::apply(out.memptr(), B, A.memptr()); }
- else if( ((B.n_rows == 1) || (TB::is_row)) && (is_cx<eT>::no) ) { gemv<false, false, false>::apply(out.memptr(), A, B.memptr()); }
- else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx<eT>::no) ) { syrk<false, false, false>::apply(out, A ); }
- else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx<eT>::yes) ) { herk<false, false, false>::apply(out, A ); }
- else { gemm<false, true, false, false>::apply(out, A, B ); }
- }
- else
- if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) )
- {
- if( ((A.n_rows == 1) || (TA::is_row)) && (is_cx<eT>::no) ) { gemv<false, true, false>::apply(out.memptr(), B, A.memptr(), alpha); }
- else if( ((B.n_rows == 1) || (TB::is_row)) && (is_cx<eT>::no) ) { gemv<false, true, false>::apply(out.memptr(), A, B.memptr(), alpha); }
- else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx<eT>::no) ) { syrk<false, true, false>::apply(out, A, alpha); }
- else { gemm<false, true, true, false>::apply(out, A, B, alpha); }
- }
- else
- if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) )
- {
- if( ((A.n_cols == 1) || (TA::is_col)) && (is_cx<eT>::no) ) { gemv<false, false, false>::apply(out.memptr(), B, A.memptr()); }
- else if( ((B.n_rows == 1) || (TB::is_row)) && (is_cx<eT>::no) ) { gemv<true, false, false>::apply(out.memptr(), A, B.memptr()); }
- else { gemm<true, true, false, false>::apply(out, A, B ); }
- }
- else
- if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) )
- {
- if( ((A.n_cols == 1) || (TA::is_col)) && (is_cx<eT>::no) ) { gemv<false, true, false>::apply(out.memptr(), B, A.memptr(), alpha); }
- else if( ((B.n_rows == 1) || (TB::is_row)) && (is_cx<eT>::no) ) { gemv<true, true, false>::apply(out.memptr(), A, B.memptr(), alpha); }
- else { gemm<true, true, true, false>::apply(out, A, B, alpha); }
- }
- }
- //! Three-matrix product A*B*C (each operand subject to its do_trans flag), written to 'out'.
- //! The association order is chosen by comparing the mul_storage_cost() of the two possible
- //! intermediates: (A*B)*C is used when A*B is not larger than B*C, otherwise A*(B*C).
- //! alpha (when use_alpha is set) is applied while forming the intermediate; the second
- //! multiplication is invoked with use_alpha = false and a dummy eT(0).
- template
-   <
-   typename   eT,
-   const bool do_trans_A,
-   const bool do_trans_B,
-   const bool do_trans_C,
-   const bool use_alpha,
-   typename   TA,
-   typename   TB,
-   typename   TC
-   >
- arma_hot
- inline
- void
- glue_times::apply
-   (
-         Mat<eT>& out,
-   const TA&      A,
-   const TB&      B,
-   const TC&      C,
-   const eT       alpha
-   )
-   {
-   arma_extra_debug_sigprint();
- 
-   Mat<eT> partial;
- 
-   const uword cost_AB = glue_times::mul_storage_cost<eT, do_trans_A, do_trans_B>(A, B);
-   const uword cost_BC = glue_times::mul_storage_cost<eT, do_trans_B, do_trans_C>(B, C);
- 
-   if(cost_BC < cost_AB)
-     {
-     // out = A*(B*C); the intermediate is passed with its transpose flag set to false
- 
-     glue_times::apply<eT, do_trans_B, do_trans_C, use_alpha>(partial, B, C, alpha);
-     glue_times::apply<eT, do_trans_A, false,      false    >(out, A, partial, eT(0));
-     }
-   else
-     {
-     // out = (A*B)*C; the intermediate is passed with its transpose flag set to false
- 
-     glue_times::apply<eT, do_trans_A, do_trans_B, use_alpha>(partial, A, B, alpha);
-     glue_times::apply<eT, false,      do_trans_C, false    >(out, partial, C, eT(0));
-     }
-   }
- //! Four-matrix product A*B*C*D (each operand subject to its do_trans flag), written to 'out'.
- //! The association order is chosen by comparing mul_storage_cost() for the (A, C) and (B, D)
- //! pairs, i.e. the sizes of the A*B*C and B*C*D intermediates: (A*B*C)*D is used when the former
- //! is not larger, otherwise A*(B*C*D). alpha (when use_alpha is set) is applied while forming the
- //! three-matrix intermediate; the final multiplication is invoked with use_alpha = false and eT(0).
- template
-   <
-   typename   eT,
-   const bool do_trans_A,
-   const bool do_trans_B,
-   const bool do_trans_C,
-   const bool do_trans_D,
-   const bool use_alpha,
-   typename   TA,
-   typename   TB,
-   typename   TC,
-   typename   TD
-   >
- arma_hot
- inline
- void
- glue_times::apply
-   (
-         Mat<eT>& out,
-   const TA&      A,
-   const TB&      B,
-   const TC&      C,
-   const TD&      D,
-   const eT       alpha
-   )
-   {
-   arma_extra_debug_sigprint();
- 
-   Mat<eT> partial;
- 
-   const uword cost_ABC = glue_times::mul_storage_cost<eT, do_trans_A, do_trans_C>(A, C);
-   const uword cost_BCD = glue_times::mul_storage_cost<eT, do_trans_B, do_trans_D>(B, D);
- 
-   if(cost_BCD < cost_ABC)
-     {
-     // out = A*(B*C*D)
- 
-     glue_times::apply<eT, do_trans_B, do_trans_C, do_trans_D, use_alpha>(partial, B, C, D, alpha);
- 
-     glue_times::apply<eT, do_trans_A, false, false>(out, A, partial, eT(0));
-     }
-   else
-     {
-     // out = (A*B*C)*D
- 
-     glue_times::apply<eT, do_trans_A, do_trans_B, do_trans_C, use_alpha>(partial, A, B, C, alpha);
- 
-     glue_times::apply<eT, false, do_trans_D, false>(out, partial, D, eT(0));
-     }
-   }
- //
- // glue_times_diag
- //! Product in which one or both operands are flagged by strip_diagmat as wrapped in diagmat().
- //! A flagged operand is accessed through diagmat_proxy_check with a single index running up to
- //! min(n_rows, n_cols). 'out' is zero-filled at the result size and only the affected entries are
- //! then written. Operands are unwrapped / proxied with 'out' supplied to the *_check helpers.
- //! Nothing is done when neither operand is flagged.
- template<typename T1, typename T2>
- arma_hot
- inline
- void
- glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_times_diag>& X)
-   {
-   arma_extra_debug_sigprint();
- 
-   typedef typename T1::elem_type eT;
- 
-   typedef strip_diagmat<T1> strip_lhs;
-   typedef strip_diagmat<T2> strip_rhs;
- 
-   const strip_lhs S1(X.A);
-   const strip_rhs S2(X.B);
- 
-   typedef typename strip_lhs::stored_type T1_stripped;
-   typedef typename strip_rhs::stored_type T2_stripped;
- 
-   const bool lhs_is_diag = strip_lhs::do_diagmat;
-   const bool rhs_is_diag = strip_rhs::do_diagmat;
- 
-   if( lhs_is_diag && (rhs_is_diag == false) )
-     {
-     arma_extra_debug_print("glue_times_diag::apply(): diagmat(A) * B");
- 
-     const diagmat_proxy_check<T1_stripped> D(S1.M, out);
- 
-     const unwrap_check<T2> rhs(X.B, out);
-     const Mat<eT>&         B = rhs.M;
- 
-     const uword D_n_rows = D.n_rows;
-     const uword D_n_cols = D.n_cols;
-     const uword n_diag   = (std::min)(D_n_rows, D_n_cols);
- 
-     const uword B_n_rows = B.n_rows;
-     const uword B_n_cols = B.n_cols;
- 
-     arma_debug_assert_mul_size(D_n_rows, D_n_cols, B_n_rows, B_n_cols, "matrix multiplication");
- 
-     out.zeros(D_n_rows, B_n_cols);
- 
-     // within each column, the first n_diag entries of B are scaled by the matching entry of D
-     for(uword c=0; c < B_n_cols; ++c)
-       {
-       const eT* src = B.colptr(c);
-             eT* dst = out.colptr(c);
- 
-       for(uword r=0; r < n_diag; ++r)  { dst[r] = D[r] * src[r]; }
-       }
-     }
-   else
-   if( (lhs_is_diag == false) && rhs_is_diag )
-     {
-     arma_extra_debug_print("glue_times_diag::apply(): A * diagmat(B)");
- 
-     const unwrap_check<T1> lhs(X.A, out);
-     const Mat<eT>&         A = lhs.M;
- 
-     const diagmat_proxy_check<T2_stripped> D(S2.M, out);
- 
-     const uword A_n_rows = A.n_rows;
-     const uword A_n_cols = A.n_cols;
- 
-     const uword D_n_rows = D.n_rows;
-     const uword D_n_cols = D.n_cols;
-     const uword n_diag   = (std::min)(D_n_rows, D_n_cols);
- 
-     arma_debug_assert_mul_size(A_n_rows, A_n_cols, D_n_rows, D_n_cols, "matrix multiplication");
- 
-     out.zeros(A_n_rows, D_n_cols);
- 
-     // each of the first n_diag columns of A is scaled by the matching entry of D
-     for(uword c=0; c < n_diag; ++c)
-       {
-       const eT scale = D[c];
- 
-       const eT* src = A.colptr(c);
-             eT* dst = out.colptr(c);
- 
-       for(uword r=0; r < A_n_rows; ++r)  { dst[r] = src[r] * scale; }
-       }
-     }
-   else
-   if( lhs_is_diag && rhs_is_diag )
-     {
-     arma_extra_debug_print("glue_times_diag::apply(): diagmat(A) * diagmat(B)");
- 
-     const diagmat_proxy_check<T1_stripped> DA(S1.M, out);
-     const diagmat_proxy_check<T2_stripped> DB(S2.M, out);
- 
-     arma_debug_assert_mul_size(DA.n_rows, DA.n_cols, DB.n_rows, DB.n_cols, "matrix multiplication");
- 
-     out.zeros(DA.n_rows, DB.n_cols);
- 
-     const uword n_diag_A = (std::min)(DA.n_rows, DA.n_cols);
-     const uword n_diag_B = (std::min)(DB.n_rows, DB.n_cols);
- 
-     const uword n_common = (std::min)(n_diag_A, n_diag_B);
- 
-     // only the leading diagonal entries shared by both operands are written
-     for(uword k=0; k < n_common; ++k)  { out.at(k,k) = DA[k] * DB[k]; }
-     }
-   }
- //! @}
|