123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762277227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376237723782379238023812382238323842385238623872388238923902391239223932394239523962397239823992400240124022403240424052406240724082409241024112412241324142415241624172418241924202421242224232424242524262427242824292430243124322433243424352436243724382439244024412442244324442445244624472448244924502451245224532454245524562457245824592460246124622463246424652466246724682469247024712472247324742475247624772478247924802481248224832484248524862487248824892490249124922493249424952496249724982499250025012502250325042505250625072508250925102511251225132514251525162517251825192520252125222523252425252526252725282529253025312532253325342535253625372538253925402541254225432544254525462547254825492550255125522553255425552556255725582559256025612562256325642565256625672568256925702571257225732574257525762577257825792580258125822583258425852586258725882589259025912592259325942595259625972598259926002601260226032604260526062607260826092610261126122613261426152616261726182619262026212622262326242625262626272628262926302631263226332634263526362637263826392640264126422643264426452646264726482649265026512652265326542655265626572658265926602661266226632664266526662667266826692670267126722673267426752676267726782679268026812682268326842685268626872688268926902691269226932694269526962697269826992700270127022703270427052706270727082709271027112712271327142715271627172718271927202721272227232724272527262727272827292730273127322733 |
- // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
- // Copyright 2008-2016 National ICT Australia (NICTA)
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- // ------------------------------------------------------------------------
- //! \addtogroup gmm_full
- //! @{
- namespace gmm_priv
- {
- template<typename eT>
- inline
- gmm_full<eT>::~gmm_full()
- {
- arma_extra_debug_sigprint_this(this);
-
- arma_type_check(( (is_same_type<eT,float>::value == false) && (is_same_type<eT,double>::value == false) ));
- }
- template<typename eT>
- inline
- gmm_full<eT>::gmm_full()
- {
- arma_extra_debug_sigprint_this(this);
- }
- template<typename eT>
- inline
- gmm_full<eT>::gmm_full(const gmm_full<eT>& x)
- {
- arma_extra_debug_sigprint_this(this);
-
- init(x);
- }
- template<typename eT>
- inline
- gmm_full<eT>&
- gmm_full<eT>::operator=(const gmm_full<eT>& x)
- {
- arma_extra_debug_sigprint();
-
- init(x);
-
- return *this;
- }
- template<typename eT>
- inline
- gmm_full<eT>::gmm_full(const gmm_diag<eT>& x)
- {
- arma_extra_debug_sigprint_this(this);
-
- init(x);
- }
- template<typename eT>
- inline
- gmm_full<eT>&
- gmm_full<eT>::operator=(const gmm_diag<eT>& x)
- {
- arma_extra_debug_sigprint();
-
- init(x);
-
- return *this;
- }
- template<typename eT>
- inline
- gmm_full<eT>::gmm_full(const uword in_n_dims, const uword in_n_gaus)
- {
- arma_extra_debug_sigprint_this(this);
-
- init(in_n_dims, in_n_gaus);
- }
- template<typename eT>
- inline
- void
- gmm_full<eT>::reset()
- {
- arma_extra_debug_sigprint();
-
- init(0, 0);
- }
- template<typename eT>
- inline
- void
- gmm_full<eT>::reset(const uword in_n_dims, const uword in_n_gaus)
- {
- arma_extra_debug_sigprint();
-
- init(in_n_dims, in_n_gaus);
- }
- template<typename eT>
- template<typename T1, typename T2, typename T3>
- inline
- void
- gmm_full<eT>::set_params(const Base<eT,T1>& in_means_expr, const BaseCube<eT,T2>& in_fcovs_expr, const Base<eT,T3>& in_hefts_expr)
- {
- arma_extra_debug_sigprint();
-
- const unwrap <T1> tmp1(in_means_expr.get_ref());
- const unwrap_cube<T2> tmp2(in_fcovs_expr.get_ref());
- const unwrap <T3> tmp3(in_hefts_expr.get_ref());
-
- const Mat <eT>& in_means = tmp1.M;
- const Cube<eT>& in_fcovs = tmp2.M;
- const Mat <eT>& in_hefts = tmp3.M;
-
- arma_debug_check
- (
- (in_means.n_cols != in_fcovs.n_slices) || (in_means.n_rows != in_fcovs.n_rows) || (in_fcovs.n_rows != in_fcovs.n_cols) || (in_hefts.n_cols != in_means.n_cols) || (in_hefts.n_rows != 1),
- "gmm_full::set_params(): given parameters have inconsistent and/or wrong sizes"
- );
-
- arma_debug_check( (in_means.is_finite() == false), "gmm_full::set_params(): given means have non-finite values" );
- arma_debug_check( (in_fcovs.is_finite() == false), "gmm_full::set_params(): given fcovs have non-finite values" );
- arma_debug_check( (in_hefts.is_finite() == false), "gmm_full::set_params(): given hefts have non-finite values" );
-
- for(uword g=0; g < in_fcovs.n_slices; ++g)
- {
- arma_debug_check( (any(diagvec(in_fcovs.slice(g)) <= eT(0))), "gmm_full::set_params(): given fcovs have negative or zero values on diagonals" );
- }
-
- arma_debug_check( (any(vectorise(in_hefts) < eT(0))), "gmm_full::set_params(): given hefts have negative values" );
-
- const eT s = accu(in_hefts);
-
- arma_debug_check( ((s < (eT(1) - eT(0.001))) || (s > (eT(1) + eT(0.001)))), "gmm_full::set_params(): sum of given hefts is not 1" );
-
- access::rw(means) = in_means;
- access::rw(fcovs) = in_fcovs;
- access::rw(hefts) = in_hefts;
-
- init_constants();
- }
- template<typename eT>
- template<typename T1>
- inline
- void
- gmm_full<eT>::set_means(const Base<eT,T1>& in_means_expr)
- {
- arma_extra_debug_sigprint();
-
- const unwrap<T1> tmp(in_means_expr.get_ref());
-
- const Mat<eT>& in_means = tmp.M;
-
- arma_debug_check( (arma::size(in_means) != arma::size(means)), "gmm_full::set_means(): given means have incompatible size" );
- arma_debug_check( (in_means.is_finite() == false), "gmm_full::set_means(): given means have non-finite values" );
-
- access::rw(means) = in_means;
- }
- template<typename eT>
- template<typename T1>
- inline
- void
- gmm_full<eT>::set_fcovs(const BaseCube<eT,T1>& in_fcovs_expr)
- {
- arma_extra_debug_sigprint();
-
- const unwrap_cube<T1> tmp(in_fcovs_expr.get_ref());
-
- const Cube<eT>& in_fcovs = tmp.M;
-
- arma_debug_check( (arma::size(in_fcovs) != arma::size(fcovs)), "gmm_full::set_fcovs(): given fcovs have incompatible size" );
- arma_debug_check( (in_fcovs.is_finite() == false), "gmm_full::set_fcovs(): given fcovs have non-finite values" );
-
- for(uword i=0; i < in_fcovs.n_slices; ++i)
- {
- arma_debug_check( (any(diagvec(in_fcovs.slice(i)) <= eT(0))), "gmm_full::set_fcovs(): given fcovs have negative or zero values on diagonals" );
- }
-
- access::rw(fcovs) = in_fcovs;
-
- init_constants();
- }
- template<typename eT>
- template<typename T1>
- inline
- void
- gmm_full<eT>::set_hefts(const Base<eT,T1>& in_hefts_expr)
- {
- arma_extra_debug_sigprint();
-
- const unwrap<T1> tmp(in_hefts_expr.get_ref());
-
- const Mat<eT>& in_hefts = tmp.M;
-
- arma_debug_check( (arma::size(in_hefts) != arma::size(hefts)), "gmm_full::set_hefts(): given hefts have incompatible size" );
- arma_debug_check( (in_hefts.is_finite() == false), "gmm_full::set_hefts(): given hefts have non-finite values" );
- arma_debug_check( (any(vectorise(in_hefts) < eT(0))), "gmm_full::set_hefts(): given hefts have negative values" );
-
- const eT s = accu(in_hefts);
-
- arma_debug_check( ((s < (eT(1) - eT(0.001))) || (s > (eT(1) + eT(0.001)))), "gmm_full::set_hefts(): sum of given hefts is not 1" );
-
- // make sure all hefts are positive and non-zero
-
- const eT* in_hefts_mem = in_hefts.memptr();
- eT* hefts_mem = access::rw(hefts).memptr();
-
- for(uword i=0; i < hefts.n_elem; ++i)
- {
- hefts_mem[i] = (std::max)( in_hefts_mem[i], std::numeric_limits<eT>::min() );
- }
-
- access::rw(hefts) /= accu(hefts);
-
- log_hefts = log(hefts);
- }
- template<typename eT>
- inline
- uword
- gmm_full<eT>::n_dims() const
- {
- return means.n_rows;
- }
- template<typename eT>
- inline
- uword
- gmm_full<eT>::n_gaus() const
- {
- return means.n_cols;
- }
- template<typename eT>
- inline
- bool
- gmm_full<eT>::load(const std::string name)
- {
- arma_extra_debug_sigprint();
-
- field< Mat<eT> > storage;
-
- bool status = storage.load(name, arma_binary);
-
- if( (status == false) || (storage.n_elem < 2) )
- {
- reset();
- arma_debug_warn("gmm_full::load(): problem with loading or incompatible format");
- return false;
- }
-
- uword count = 0;
-
- const Mat<eT>& storage_means = storage(count); ++count;
- const Mat<eT>& storage_hefts = storage(count); ++count;
-
- const uword N_dims = storage_means.n_rows;
- const uword N_gaus = storage_means.n_cols;
-
- if( (storage.n_elem != (N_gaus + 2)) || (storage_hefts.n_rows != 1) || (storage_hefts.n_cols != N_gaus) )
- {
- reset();
- arma_debug_warn("gmm_full::load(): incompatible format");
- return false;
- }
-
- reset(N_dims, N_gaus);
-
- access::rw(means) = storage_means;
- access::rw(hefts) = storage_hefts;
-
- for(uword g=0; g < N_gaus; ++g)
- {
- const Mat<eT>& storage_fcov = storage(count); ++count;
-
- if( (storage_fcov.n_rows != N_dims) || (storage_fcov.n_cols != N_dims) )
- {
- reset();
- arma_debug_warn("gmm_full::load(): incompatible format");
- return false;
- }
-
- access::rw(fcovs).slice(g) = storage_fcov;
- }
-
- init_constants();
-
- return true;
- }
- template<typename eT>
- inline
- bool
- gmm_full<eT>::save(const std::string name) const
- {
- arma_extra_debug_sigprint();
-
- const uword N_gaus = means.n_cols;
-
- field< Mat<eT> > storage(2 + N_gaus);
-
- uword count = 0;
-
- storage(count) = means; ++count;
- storage(count) = hefts; ++count;
-
- for(uword g=0; g < N_gaus; ++g)
- {
- storage(count) = fcovs.slice(g); ++count;
- }
-
- const bool status = storage.save(name, arma_binary);
-
- return status;
- }
- template<typename eT>
- inline
- Col<eT>
- gmm_full<eT>::generate() const
- {
- arma_extra_debug_sigprint();
-
- const uword N_dims = means.n_rows;
- const uword N_gaus = means.n_cols;
-
- Col<eT> out( (N_gaus > 0) ? N_dims : uword(0) );
- Col<eT> tmp( (N_gaus > 0) ? N_dims : uword(0), fill::randn );
-
- if(N_gaus > 0)
- {
- const double val = randu<double>();
-
- double csum = double(0);
- uword gaus_id = 0;
-
- for(uword j=0; j < N_gaus; ++j)
- {
- csum += hefts[j];
-
- if(val <= csum) { gaus_id = j; break; }
- }
-
- out = chol_fcovs.slice(gaus_id) * tmp;
- out += means.col(gaus_id);
- }
-
- return out;
- }
- template<typename eT>
- inline
- Mat<eT>
- gmm_full<eT>::generate(const uword N_vec) const
- {
- arma_extra_debug_sigprint();
-
- const uword N_dims = means.n_rows;
- const uword N_gaus = means.n_cols;
-
- Mat<eT> out( ( (N_gaus > 0) ? N_dims : uword(0) ), N_vec );
- Mat<eT> tmp( ( (N_gaus > 0) ? N_dims : uword(0) ), N_vec, fill::randn );
-
- if(N_gaus > 0)
- {
- const eT* hefts_mem = hefts.memptr();
-
- for(uword i=0; i < N_vec; ++i)
- {
- const double val = randu<double>();
-
- double csum = double(0);
- uword gaus_id = 0;
-
- for(uword j=0; j < N_gaus; ++j)
- {
- csum += hefts_mem[j];
-
- if(val <= csum) { gaus_id = j; break; }
- }
-
- Col<eT> out_vec(out.colptr(i), N_dims, false, true);
- Col<eT> tmp_vec(tmp.colptr(i), N_dims, false, true);
-
- out_vec = chol_fcovs.slice(gaus_id) * tmp_vec;
- out_vec += means.col(gaus_id);
- }
- }
-
- return out;
- }
- template<typename eT>
- template<typename T1>
- inline
- eT
- gmm_full<eT>::log_p(const T1& expr, const gmm_empty_arg& junk1, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == true))>::result* junk2) const
- {
- arma_extra_debug_sigprint();
- arma_ignore(junk1);
- arma_ignore(junk2);
-
- const uword N_dims = means.n_rows;
-
- const quasi_unwrap<T1> U(expr);
-
- arma_debug_check( (U.M.n_rows != N_dims), "gmm_full::log_p(): incompatible dimensions" );
-
- return internal_scalar_log_p( U.M.memptr() );
- }
- template<typename eT>
- template<typename T1>
- inline
- eT
- gmm_full<eT>::log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == true))>::result* junk2) const
- {
- arma_extra_debug_sigprint();
- arma_ignore(junk2);
-
- const uword N_dims = means.n_rows;
-
- const quasi_unwrap<T1> U(expr);
-
- arma_debug_check( (U.M.n_rows != N_dims), "gmm_full::log_p(): incompatible dimensions" );
- arma_debug_check( (gaus_id >= means.n_cols), "gmm_full::log_p(): specified gaussian is out of range" );
-
- return internal_scalar_log_p( U.M.memptr(), gaus_id );
- }
- template<typename eT>
- template<typename T1>
- inline
- Row<eT>
- gmm_full<eT>::log_p(const T1& expr, const gmm_empty_arg& junk1, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == false))>::result* junk2) const
- {
- arma_extra_debug_sigprint();
- arma_ignore(junk1);
- arma_ignore(junk2);
-
- const quasi_unwrap<T1> tmp(expr);
-
- const Mat<eT>& X = tmp.M;
-
- return internal_vec_log_p(X);
- }
- template<typename eT>
- template<typename T1>
- inline
- Row<eT>
- gmm_full<eT>::log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == false))>::result* junk2) const
- {
- arma_extra_debug_sigprint();
- arma_ignore(junk2);
-
- const quasi_unwrap<T1> tmp(expr);
-
- const Mat<eT>& X = tmp.M;
-
- return internal_vec_log_p(X, gaus_id);
- }
- template<typename eT>
- template<typename T1>
- inline
- eT
- gmm_full<eT>::sum_log_p(const Base<eT,T1>& expr) const
- {
- arma_extra_debug_sigprint();
-
- const quasi_unwrap<T1> tmp(expr.get_ref());
-
- const Mat<eT>& X = tmp.M;
-
- return internal_sum_log_p(X);
- }
- template<typename eT>
- template<typename T1>
- inline
- eT
- gmm_full<eT>::sum_log_p(const Base<eT,T1>& expr, const uword gaus_id) const
- {
- arma_extra_debug_sigprint();
-
- const quasi_unwrap<T1> tmp(expr.get_ref());
-
- const Mat<eT>& X = tmp.M;
-
- return internal_sum_log_p(X, gaus_id);
- }
- template<typename eT>
- template<typename T1>
- inline
- eT
- gmm_full<eT>::avg_log_p(const Base<eT,T1>& expr) const
- {
- arma_extra_debug_sigprint();
-
- const quasi_unwrap<T1> tmp(expr.get_ref());
-
- const Mat<eT>& X = tmp.M;
-
- return internal_avg_log_p(X);
- }
- template<typename eT>
- template<typename T1>
- inline
- eT
- gmm_full<eT>::avg_log_p(const Base<eT,T1>& expr, const uword gaus_id) const
- {
- arma_extra_debug_sigprint();
-
- const quasi_unwrap<T1> tmp(expr.get_ref());
-
- const Mat<eT>& X = tmp.M;
-
- return internal_avg_log_p(X, gaus_id);
- }
- template<typename eT>
- template<typename T1>
- inline
- uword
- gmm_full<eT>::assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == true))>::result* junk) const
- {
- arma_extra_debug_sigprint();
- arma_ignore(junk);
-
- const quasi_unwrap<T1> tmp(expr);
-
- const Mat<eT>& X = tmp.M;
-
- return internal_scalar_assign(X, dist);
- }
- template<typename eT>
- template<typename T1>
- inline
- urowvec
- gmm_full<eT>::assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == false))>::result* junk) const
- {
- arma_extra_debug_sigprint();
- arma_ignore(junk);
-
- urowvec out;
-
- const quasi_unwrap<T1> tmp(expr);
-
- const Mat<eT>& X = tmp.M;
-
- internal_vec_assign(out, X, dist);
-
- return out;
- }
- template<typename eT>
- template<typename T1>
- inline
- urowvec
- gmm_full<eT>::raw_hist(const Base<eT,T1>& expr, const gmm_dist_mode& dist_mode) const
- {
- arma_extra_debug_sigprint();
-
- const unwrap<T1> tmp(expr.get_ref());
- const Mat<eT>& X = tmp.M;
-
- arma_debug_check( (X.n_rows != means.n_rows), "gmm_full::raw_hist(): incompatible dimensions" );
-
- arma_debug_check( ((dist_mode != eucl_dist) && (dist_mode != prob_dist)), "gmm_full::raw_hist(): unsupported distance mode" );
-
- urowvec hist;
-
- internal_raw_hist(hist, X, dist_mode);
-
- return hist;
- }
- template<typename eT>
- template<typename T1>
- inline
- Row<eT>
- gmm_full<eT>::norm_hist(const Base<eT,T1>& expr, const gmm_dist_mode& dist_mode) const
- {
- arma_extra_debug_sigprint();
-
- const unwrap<T1> tmp(expr.get_ref());
- const Mat<eT>& X = tmp.M;
-
- arma_debug_check( (X.n_rows != means.n_rows), "gmm_full::norm_hist(): incompatible dimensions" );
-
- arma_debug_check( ((dist_mode != eucl_dist) && (dist_mode != prob_dist)), "gmm_full::norm_hist(): unsupported distance mode" );
-
- urowvec hist;
-
- internal_raw_hist(hist, X, dist_mode);
-
- const uword hist_n_elem = hist.n_elem;
- const uword* hist_mem = hist.memptr();
-
- eT acc = eT(0);
- for(uword i=0; i<hist_n_elem; ++i) { acc += eT(hist_mem[i]); }
-
- if(acc == eT(0)) { acc = eT(1); }
-
- Row<eT> out(hist_n_elem);
-
- eT* out_mem = out.memptr();
-
- for(uword i=0; i<hist_n_elem; ++i) { out_mem[i] = eT(hist_mem[i]) / acc; }
-
- return out;
- }
- template<typename eT>
- template<typename T1>
- inline
- bool
- gmm_full<eT>::learn
- (
- const Base<eT,T1>& data,
- const uword N_gaus,
- const gmm_dist_mode& dist_mode,
- const gmm_seed_mode& seed_mode,
- const uword km_iter,
- const uword em_iter,
- const eT var_floor,
- const bool print_mode
- )
- {
- arma_extra_debug_sigprint();
-
- const bool dist_mode_ok = (dist_mode == eucl_dist) || (dist_mode == maha_dist);
-
- const bool seed_mode_ok = \
- (seed_mode == keep_existing)
- || (seed_mode == static_subset)
- || (seed_mode == static_spread)
- || (seed_mode == random_subset)
- || (seed_mode == random_spread);
-
- arma_debug_check( (dist_mode_ok == false), "gmm_full::learn(): dist_mode must be eucl_dist or maha_dist" );
- arma_debug_check( (seed_mode_ok == false), "gmm_full::learn(): unknown seed_mode" );
- arma_debug_check( (var_floor < eT(0) ), "gmm_full::learn(): variance floor is negative" );
-
- const unwrap<T1> tmp_X(data.get_ref());
- const Mat<eT>& X = tmp_X.M;
-
- if(X.is_empty() ) { arma_debug_warn("gmm_full::learn(): given matrix is empty" ); return false; }
- if(X.is_finite() == false) { arma_debug_warn("gmm_full::learn(): given matrix has non-finite values"); return false; }
-
- if(N_gaus == 0) { reset(); return true; }
-
- if(dist_mode == maha_dist)
- {
- mah_aux = var(X,1,1);
-
- const uword mah_aux_n_elem = mah_aux.n_elem;
- eT* mah_aux_mem = mah_aux.memptr();
-
- for(uword i=0; i < mah_aux_n_elem; ++i)
- {
- const eT val = mah_aux_mem[i];
-
- mah_aux_mem[i] = ((val != eT(0)) && arma_isfinite(val)) ? eT(1) / val : eT(1);
- }
- }
-
-
- // copy current model, in case of failure by k-means and/or EM
-
- const gmm_full<eT> orig = (*this);
-
-
- // initial means
-
- if(seed_mode == keep_existing)
- {
- if(means.is_empty() ) { arma_debug_warn("gmm_full::learn(): no existing means" ); return false; }
- if(X.n_rows != means.n_rows) { arma_debug_warn("gmm_full::learn(): dimensionality mismatch"); return false; }
-
- // TODO: also check for number of vectors?
- }
- else
- {
- if(X.n_cols < N_gaus) { arma_debug_warn("gmm_full::learn(): number of vectors is less than number of gaussians"); return false; }
-
- reset(X.n_rows, N_gaus);
-
- if(print_mode) { get_cout_stream() << "gmm_full::learn(): generating initial means\n"; get_cout_stream().flush(); }
-
- if(dist_mode == eucl_dist) { generate_initial_means<1>(X, seed_mode); }
- else if(dist_mode == maha_dist) { generate_initial_means<2>(X, seed_mode); }
- }
-
-
- // k-means
-
- if(km_iter > 0)
- {
- const arma_ostream_state stream_state(get_cout_stream());
-
- bool status = false;
-
- if(dist_mode == eucl_dist) { status = km_iterate<1>(X, km_iter, print_mode); }
- else if(dist_mode == maha_dist) { status = km_iterate<2>(X, km_iter, print_mode); }
-
- stream_state.restore(get_cout_stream());
-
- if(status == false) { arma_debug_warn("gmm_full::learn(): k-means algorithm failed; not enough data, or too many gaussians requested"); init(orig); return false; }
- }
-
-
- // initial fcovs
-
- const eT var_floor_actual = (eT(var_floor) > eT(0)) ? eT(var_floor) : std::numeric_limits<eT>::min();
-
- if(seed_mode != keep_existing)
- {
- if(print_mode) { get_cout_stream() << "gmm_full::learn(): generating initial covariances\n"; get_cout_stream().flush(); }
-
- if(dist_mode == eucl_dist) { generate_initial_params<1>(X, var_floor_actual); }
- else if(dist_mode == maha_dist) { generate_initial_params<2>(X, var_floor_actual); }
- }
-
-
- // EM algorithm
-
- if(em_iter > 0)
- {
- const arma_ostream_state stream_state(get_cout_stream());
-
- const bool status = em_iterate(X, em_iter, var_floor_actual, print_mode);
-
- stream_state.restore(get_cout_stream());
-
- if(status == false) { arma_debug_warn("gmm_full::learn(): EM algorithm failed"); init(orig); return false; }
- }
-
- mah_aux.reset();
-
- init_constants();
-
- return true;
- }
- //
- //
- //
- template<typename eT>
- inline
- void
- gmm_full<eT>::init(const gmm_full<eT>& x)
- {
- arma_extra_debug_sigprint();
-
- gmm_full<eT>& t = *this;
-
- if(&t != &x)
- {
- access::rw(t.means) = x.means;
- access::rw(t.fcovs) = x.fcovs;
- access::rw(t.hefts) = x.hefts;
-
- init_constants();
- }
- }
- template<typename eT>
- inline
- void
- gmm_full<eT>::init(const gmm_diag<eT>& x)
- {
- arma_extra_debug_sigprint();
-
- access::rw(hefts) = x.hefts;
- access::rw(means) = x.means;
-
- const uword N_dims = x.means.n_rows;
- const uword N_gaus = x.means.n_cols;
-
- access::rw(fcovs).zeros(N_dims,N_dims,N_gaus);
-
- for(uword g=0; g < N_gaus; ++g)
- {
- Mat<eT>& fcov = access::rw(fcovs).slice(g);
-
- const eT* dcov_mem = x.dcovs.colptr(g);
-
- for(uword d=0; d < N_dims; ++d)
- {
- fcov.at(d,d) = dcov_mem[d];
- }
- }
-
- init_constants();
- }
- template<typename eT>
- inline
- void
- gmm_full<eT>::init(const uword in_n_dims, const uword in_n_gaus)
- {
- arma_extra_debug_sigprint();
-
- access::rw(means).zeros(in_n_dims, in_n_gaus);
-
- access::rw(fcovs).zeros(in_n_dims, in_n_dims, in_n_gaus);
-
- for(uword g=0; g < in_n_gaus; ++g)
- {
- access::rw(fcovs).slice(g).diag().ones();
- }
-
- access::rw(hefts).set_size(in_n_gaus);
- access::rw(hefts).fill(eT(1) / eT(in_n_gaus));
-
- init_constants();
- }
- template<typename eT>
- inline
- void
- gmm_full<eT>::init_constants(const bool calc_chol)
- {
- arma_extra_debug_sigprint();
-
- const uword N_dims = means.n_rows;
- const uword N_gaus = means.n_cols;
-
- const eT tmp = (eT(N_dims)/eT(2)) * std::log(eT(2) * Datum<eT>::pi);
-
- //
-
- inv_fcovs.copy_size(fcovs);
- log_det_etc.set_size(N_gaus);
-
- Mat<eT> tmp_inv;
-
- for(uword g=0; g < N_gaus; ++g)
- {
- const Mat<eT>& fcov = fcovs.slice(g);
- Mat<eT>& inv_fcov = inv_fcovs.slice(g);
-
- //const bool inv_ok = auxlib::inv(tmp_inv, fcov);
- const bool inv_ok = auxlib::inv_sympd(tmp_inv, fcov);
-
- eT log_det_val = eT(0);
- eT log_det_sign = eT(0);
-
- log_det(log_det_val, log_det_sign, fcov);
-
- const bool log_det_ok = ( (arma_isfinite(log_det_val)) && (log_det_sign > eT(0)) );
-
- if(inv_ok && log_det_ok)
- {
- inv_fcov = tmp_inv;
- }
- else
- {
- // last resort: treat the covariance matrix as diagonal
-
- inv_fcov.zeros();
-
- log_det_val = eT(0);
-
- for(uword d=0; d < N_dims; ++d)
- {
- const eT sanitised_val = (std::max)( eT(fcov.at(d,d)), eT(std::numeric_limits<eT>::min()) );
-
- inv_fcov.at(d,d) = eT(1) / sanitised_val;
-
- log_det_val += std::log(sanitised_val);
- }
- }
-
- log_det_etc[g] = eT(-1) * ( tmp + eT(0.5) * log_det_val );
- }
-
- //
-
- eT* hefts_mem = access::rw(hefts).memptr();
-
- for(uword g=0; g < N_gaus; ++g)
- {
- hefts_mem[g] = (std::max)( hefts_mem[g], std::numeric_limits<eT>::min() );
- }
-
- log_hefts = log(hefts);
-
-
- if(calc_chol)
- {
- chol_fcovs.copy_size(fcovs);
-
- Mat<eT> tmp_chol;
-
- for(uword g=0; g < N_gaus; ++g)
- {
- const Mat<eT>& fcov = fcovs.slice(g);
- Mat<eT>& chol_fcov = chol_fcovs.slice(g);
-
- const uword chol_layout = 1; // indicates "lower"
-
- const bool chol_ok = op_chol::apply_direct(tmp_chol, fcov, chol_layout);
-
- if(chol_ok)
- {
- chol_fcov = tmp_chol;
- }
- else
- {
- // last resort: treat the covariance matrix as diagonal
-
- chol_fcov.zeros();
-
- for(uword d=0; d < N_dims; ++d)
- {
- const eT sanitised_val = (std::max)( eT(fcov.at(d,d)), eT(std::numeric_limits<eT>::min()) );
-
- chol_fcov.at(d,d) = std::sqrt(sanitised_val);
- }
- }
- }
- }
- }
- template<typename eT>
- inline
- umat
- gmm_full<eT>::internal_gen_boundaries(const uword N) const
- {
- arma_extra_debug_sigprint();
-
- #if defined(ARMA_USE_OPENMP)
- const uword n_threads_avail = uword(omp_get_max_threads());
- const uword n_threads = (n_threads_avail > 0) ? ( (n_threads_avail <= N) ? n_threads_avail : 1 ) : 1;
- #else
- static const uword n_threads = 1;
- #endif
-
- // get_cout_stream() << "gmm_full::internal_gen_boundaries(): n_threads: " << n_threads << '\n';
-
- umat boundaries(2, n_threads);
-
- if(N > 0)
- {
- const uword chunk_size = N / n_threads;
-
- uword count = 0;
-
- for(uword t=0; t<n_threads; t++)
- {
- boundaries.at(0,t) = count;
-
- count += chunk_size;
-
- boundaries.at(1,t) = count-1;
- }
-
- boundaries.at(1,n_threads-1) = N - 1;
- }
- else
- {
- boundaries.zeros();
- }
-
- // get_cout_stream() << "gmm_full::internal_gen_boundaries(): boundaries: " << '\n' << boundaries << '\n';
-
- return boundaries;
- }
- template<typename eT>
- inline
- eT
- gmm_full<eT>::internal_scalar_log_p(const eT* x) const
- {
- arma_extra_debug_sigprint();
-
- const eT* log_hefts_mem = log_hefts.mem;
-
- const uword N_gaus = means.n_cols;
-
- if(N_gaus > 0)
- {
- eT log_sum = internal_scalar_log_p(x, 0) + log_hefts_mem[0];
-
- for(uword g=1; g < N_gaus; ++g)
- {
- const eT log_val = internal_scalar_log_p(x, g) + log_hefts_mem[g];
-
- log_sum = log_add_exp(log_sum, log_val);
- }
-
- return log_sum;
- }
- else
- {
- return -Datum<eT>::inf;
- }
- }
- template<typename eT>
- inline
- eT
- gmm_full<eT>::internal_scalar_log_p(const eT* x, const uword g) const
- {
- arma_extra_debug_sigprint();
-
- const uword N_dims = means.n_rows;
- const eT* mean_mem = means.colptr(g);
-
- eT outer_acc = eT(0);
-
- const eT* inv_fcov_coldata = inv_fcovs.slice(g).memptr();
-
- for(uword i=0; i < N_dims; ++i)
- {
- eT inner_acc = eT(0);
-
- for(uword j=0; j < N_dims; ++j)
- {
- inner_acc += (x[j] - mean_mem[j]) * inv_fcov_coldata[j];
- }
-
- inv_fcov_coldata += N_dims;
-
- outer_acc += inner_acc * (x[i] - mean_mem[i]);
- }
-
- return eT(-0.5)*outer_acc + log_det_etc.mem[g];
- }
- template<typename eT>
- inline
- Row<eT>
- gmm_full<eT>::internal_vec_log_p(const Mat<eT>& X) const
- {
- arma_extra_debug_sigprint();
-
- const uword N_dims = means.n_rows;
- const uword N_samples = X.n_cols;
-
- arma_debug_check( (X.n_rows != N_dims), "gmm_full::log_p(): incompatible dimensions" );
-
- Row<eT> out(N_samples);
-
- if(N_samples > 0)
- {
- #if defined(ARMA_USE_OPENMP)
- {
- const umat boundaries = internal_gen_boundaries(N_samples);
-
- const uword n_threads = boundaries.n_cols;
-
- #pragma omp parallel for schedule(static)
- for(uword t=0; t < n_threads; ++t)
- {
- const uword start_index = boundaries.at(0,t);
- const uword end_index = boundaries.at(1,t);
-
- eT* out_mem = out.memptr();
-
- for(uword i=start_index; i <= end_index; ++i)
- {
- out_mem[i] = internal_scalar_log_p( X.colptr(i) );
- }
- }
- }
- #else
- {
- eT* out_mem = out.memptr();
-
- for(uword i=0; i < N_samples; ++i)
- {
- out_mem[i] = internal_scalar_log_p( X.colptr(i) );
- }
- }
- #endif
- }
-
- return out;
- }
- template<typename eT>
- inline
- Row<eT>
- gmm_full<eT>::internal_vec_log_p(const Mat<eT>& X, const uword gaus_id) const
- {
- arma_extra_debug_sigprint();
-
- const uword N_dims = means.n_rows;
- const uword N_samples = X.n_cols;
-
- arma_debug_check( (X.n_rows != N_dims), "gmm_full::log_p(): incompatible dimensions" );
- arma_debug_check( (gaus_id >= means.n_cols), "gmm_full::log_p(): specified gaussian is out of range" );
-
- Row<eT> out(N_samples);
-
- if(N_samples > 0)
- {
- #if defined(ARMA_USE_OPENMP)
- {
- const umat boundaries = internal_gen_boundaries(N_samples);
-
- const uword n_threads = boundaries.n_cols;
-
- #pragma omp parallel for schedule(static)
- for(uword t=0; t < n_threads; ++t)
- {
- const uword start_index = boundaries.at(0,t);
- const uword end_index = boundaries.at(1,t);
-
- eT* out_mem = out.memptr();
-
- for(uword i=start_index; i <= end_index; ++i)
- {
- out_mem[i] = internal_scalar_log_p( X.colptr(i), gaus_id );
- }
- }
- }
- #else
- {
- eT* out_mem = out.memptr();
-
- for(uword i=0; i < N_samples; ++i)
- {
- out_mem[i] = internal_scalar_log_p( X.colptr(i), gaus_id );
- }
- }
- #endif
- }
-
- return out;
- }
- template<typename eT>
- inline
- eT
- gmm_full<eT>::internal_sum_log_p(const Mat<eT>& X) const
- {
- arma_extra_debug_sigprint();
-
- arma_debug_check( (X.n_rows != means.n_rows), "gmm_full::sum_log_p(): incompatible dimensions" );
-
- const uword N = X.n_cols;
-
- if(N == 0) { return (-Datum<eT>::inf); }
-
-
- #if defined(ARMA_USE_OPENMP)
- {
- const umat boundaries = internal_gen_boundaries(N);
-
- const uword n_threads = boundaries.n_cols;
-
- Col<eT> t_accs(n_threads, fill::zeros);
-
- #pragma omp parallel for schedule(static)
- for(uword t=0; t < n_threads; ++t)
- {
- const uword start_index = boundaries.at(0,t);
- const uword end_index = boundaries.at(1,t);
-
- eT t_acc = eT(0);
-
- for(uword i=start_index; i <= end_index; ++i)
- {
- t_acc += internal_scalar_log_p( X.colptr(i) );
- }
-
- t_accs[t] = t_acc;
- }
-
- return eT(accu(t_accs));
- }
- #else
- {
- eT acc = eT(0);
-
- for(uword i=0; i<N; ++i)
- {
- acc += internal_scalar_log_p( X.colptr(i) );
- }
-
- return acc;
- }
- #endif
- }
- template<typename eT>
- inline
- eT
- gmm_full<eT>::internal_sum_log_p(const Mat<eT>& X, const uword gaus_id) const
- {
- arma_extra_debug_sigprint();
-
- arma_debug_check( (X.n_rows != means.n_rows), "gmm_full::sum_log_p(): incompatible dimensions" );
- arma_debug_check( (gaus_id >= means.n_cols), "gmm_full::sum_log_p(): specified gaussian is out of range" );
-
- const uword N = X.n_cols;
-
- if(N == 0) { return (-Datum<eT>::inf); }
-
-
- #if defined(ARMA_USE_OPENMP)
- {
- const umat boundaries = internal_gen_boundaries(N);
-
- const uword n_threads = boundaries.n_cols;
-
- Col<eT> t_accs(n_threads, fill::zeros);
-
- #pragma omp parallel for schedule(static)
- for(uword t=0; t < n_threads; ++t)
- {
- const uword start_index = boundaries.at(0,t);
- const uword end_index = boundaries.at(1,t);
-
- eT t_acc = eT(0);
-
- for(uword i=start_index; i <= end_index; ++i)
- {
- t_acc += internal_scalar_log_p( X.colptr(i), gaus_id );
- }
-
- t_accs[t] = t_acc;
- }
-
- return eT(accu(t_accs));
- }
- #else
- {
- eT acc = eT(0);
-
- for(uword i=0; i<N; ++i)
- {
- acc += internal_scalar_log_p( X.colptr(i), gaus_id );
- }
-
- return acc;
- }
- #endif
- }
- template<typename eT>
- inline
- eT
- gmm_full<eT>::internal_avg_log_p(const Mat<eT>& X) const
- {
- arma_extra_debug_sigprint();
-
- const uword N_dims = means.n_rows;
- const uword N_samples = X.n_cols;
-
- arma_debug_check( (X.n_rows != N_dims), "gmm_full::avg_log_p(): incompatible dimensions" );
-
- if(N_samples == 0) { return (-Datum<eT>::inf); }
-
-
- #if defined(ARMA_USE_OPENMP)
- {
- const umat boundaries = internal_gen_boundaries(N_samples);
-
- const uword n_threads = boundaries.n_cols;
-
- field< running_mean_scalar<eT> > t_running_means(n_threads);
-
-
- #pragma omp parallel for schedule(static)
- for(uword t=0; t < n_threads; ++t)
- {
- const uword start_index = boundaries.at(0,t);
- const uword end_index = boundaries.at(1,t);
-
- running_mean_scalar<eT>& current_running_mean = t_running_means[t];
-
- for(uword i=start_index; i <= end_index; ++i)
- {
- current_running_mean( internal_scalar_log_p( X.colptr(i) ) );
- }
- }
-
-
- eT avg = eT(0);
-
- for(uword t=0; t < n_threads; ++t)
- {
- running_mean_scalar<eT>& current_running_mean = t_running_means[t];
-
- const eT w = eT(current_running_mean.count()) / eT(N_samples);
-
- avg += w * current_running_mean.mean();
- }
-
- return avg;
- }
- #else
- {
- running_mean_scalar<eT> running_mean;
-
- for(uword i=0; i < N_samples; ++i)
- {
- running_mean( internal_scalar_log_p( X.colptr(i) ) );
- }
-
- return running_mean.mean();
- }
- #endif
- }
- template<typename eT>
- inline
- eT
- gmm_full<eT>::internal_avg_log_p(const Mat<eT>& X, const uword gaus_id) const
- {
- arma_extra_debug_sigprint();
-
- const uword N_dims = means.n_rows;
- const uword N_samples = X.n_cols;
-
- arma_debug_check( (X.n_rows != N_dims), "gmm_full::avg_log_p(): incompatible dimensions" );
- arma_debug_check( (gaus_id >= means.n_cols), "gmm_full::avg_log_p(): specified gaussian is out of range" );
-
- if(N_samples == 0) { return (-Datum<eT>::inf); }
-
-
- #if defined(ARMA_USE_OPENMP)
- {
- const umat boundaries = internal_gen_boundaries(N_samples);
-
- const uword n_threads = boundaries.n_cols;
-
- field< running_mean_scalar<eT> > t_running_means(n_threads);
-
-
- #pragma omp parallel for schedule(static)
- for(uword t=0; t < n_threads; ++t)
- {
- const uword start_index = boundaries.at(0,t);
- const uword end_index = boundaries.at(1,t);
-
- running_mean_scalar<eT>& current_running_mean = t_running_means[t];
-
- for(uword i=start_index; i <= end_index; ++i)
- {
- current_running_mean( internal_scalar_log_p( X.colptr(i), gaus_id) );
- }
- }
-
-
- eT avg = eT(0);
-
- for(uword t=0; t < n_threads; ++t)
- {
- running_mean_scalar<eT>& current_running_mean = t_running_means[t];
-
- const eT w = eT(current_running_mean.count()) / eT(N_samples);
-
- avg += w * current_running_mean.mean();
- }
-
- return avg;
- }
- #else
- {
- running_mean_scalar<eT> running_mean;
-
- for(uword i=0; i<N_samples; ++i)
- {
- running_mean( internal_scalar_log_p( X.colptr(i), gaus_id ) );
- }
-
- return running_mean.mean();
- }
- #endif
- }
- template<typename eT>
- inline
- uword
- gmm_full<eT>::internal_scalar_assign(const Mat<eT>& X, const gmm_dist_mode& dist_mode) const
- {
- arma_extra_debug_sigprint();
-
- const uword N_dims = means.n_rows;
- const uword N_gaus = means.n_cols;
-
- arma_debug_check( (X.n_rows != N_dims), "gmm_full::assign(): incompatible dimensions" );
- arma_debug_check( (N_gaus == 0), "gmm_full::assign(): model has no means" );
-
- const eT* X_mem = X.colptr(0);
-
- if(dist_mode == eucl_dist)
- {
- eT best_dist = Datum<eT>::inf;
- uword best_g = 0;
-
- for(uword g=0; g < N_gaus; ++g)
- {
- const eT tmp_dist = distance<eT,1>::eval(N_dims, X_mem, means.colptr(g), X_mem);
-
- if(tmp_dist <= best_dist)
- {
- best_dist = tmp_dist;
- best_g = g;
- }
- }
-
- return best_g;
- }
- else
- if(dist_mode == prob_dist)
- {
- const eT* log_hefts_mem = log_hefts.memptr();
-
- eT best_p = -Datum<eT>::inf;
- uword best_g = 0;
-
- for(uword g=0; g < N_gaus; ++g)
- {
- const eT tmp_p = internal_scalar_log_p(X_mem, g) + log_hefts_mem[g];
-
- if(tmp_p >= best_p)
- {
- best_p = tmp_p;
- best_g = g;
- }
- }
-
- return best_g;
- }
- else
- {
- arma_debug_check(true, "gmm_full::assign(): unsupported distance mode");
- }
-
- return uword(0);
- }
- template<typename eT>
- inline
- void
- gmm_full<eT>::internal_vec_assign(urowvec& out, const Mat<eT>& X, const gmm_dist_mode& dist_mode) const
- {
- arma_extra_debug_sigprint();
-
- const uword N_dims = means.n_rows;
- const uword N_gaus = means.n_cols;
-
- arma_debug_check( (X.n_rows != N_dims), "gmm_full::assign(): incompatible dimensions" );
-
- const uword X_n_cols = (N_gaus > 0) ? X.n_cols : 0;
-
- out.set_size(1,X_n_cols);
-
- uword* out_mem = out.memptr();
-
- if(dist_mode == eucl_dist)
- {
- #if defined(ARMA_USE_OPENMP)
- {
- #pragma omp parallel for schedule(static)
- for(uword i=0; i<X_n_cols; ++i)
- {
- const eT* X_colptr = X.colptr(i);
-
- eT best_dist = Datum<eT>::inf;
- uword best_g = 0;
-
- for(uword g=0; g<N_gaus; ++g)
- {
- const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.colptr(g), X_colptr);
-
- if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; }
- }
-
- out_mem[i] = best_g;
- }
- }
- #else
- {
- for(uword i=0; i<X_n_cols; ++i)
- {
- const eT* X_colptr = X.colptr(i);
-
- eT best_dist = Datum<eT>::inf;
- uword best_g = 0;
-
- for(uword g=0; g<N_gaus; ++g)
- {
- const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.colptr(g), X_colptr);
-
- if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; }
- }
-
- out_mem[i] = best_g;
- }
- }
- #endif
- }
- else
- if(dist_mode == prob_dist)
- {
- #if defined(ARMA_USE_OPENMP)
- {
- const umat boundaries = internal_gen_boundaries(X_n_cols);
-
- const uword n_threads = boundaries.n_cols;
-
- const eT* log_hefts_mem = log_hefts.memptr();
-
- #pragma omp parallel for schedule(static)
- for(uword t=0; t < n_threads; ++t)
- {
- const uword start_index = boundaries.at(0,t);
- const uword end_index = boundaries.at(1,t);
-
- for(uword i=start_index; i <= end_index; ++i)
- {
- const eT* X_colptr = X.colptr(i);
-
- eT best_p = -Datum<eT>::inf;
- uword best_g = 0;
-
- for(uword g=0; g<N_gaus; ++g)
- {
- const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g];
-
- if(tmp_p >= best_p) { best_p = tmp_p; best_g = g; }
- }
-
- out_mem[i] = best_g;
- }
- }
- }
- #else
- {
- const eT* log_hefts_mem = log_hefts.memptr();
-
- for(uword i=0; i<X_n_cols; ++i)
- {
- const eT* X_colptr = X.colptr(i);
-
- eT best_p = -Datum<eT>::inf;
- uword best_g = 0;
-
- for(uword g=0; g<N_gaus; ++g)
- {
- const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g];
-
- if(tmp_p >= best_p) { best_p = tmp_p; best_g = g; }
- }
-
- out_mem[i] = best_g;
- }
- }
- #endif
- }
- else
- {
- arma_debug_check(true, "gmm_full::assign(): unsupported distance mode");
- }
- }
- template<typename eT>
- inline
- void
- gmm_full<eT>::internal_raw_hist(urowvec& hist, const Mat<eT>& X, const gmm_dist_mode& dist_mode) const
- {
- arma_extra_debug_sigprint();
-
- const uword N_dims = means.n_rows;
- const uword N_gaus = means.n_cols;
-
- const uword X_n_cols = X.n_cols;
-
- hist.zeros(N_gaus);
-
- if(N_gaus == 0) { return; }
-
- #if defined(ARMA_USE_OPENMP)
- {
- const umat boundaries = internal_gen_boundaries(X_n_cols);
-
- const uword n_threads = boundaries.n_cols;
-
- field<urowvec> thread_hist(n_threads);
-
- for(uword t=0; t < n_threads; ++t) { thread_hist(t).zeros(N_gaus); }
-
-
- if(dist_mode == eucl_dist)
- {
- #pragma omp parallel for schedule(static)
- for(uword t=0; t < n_threads; ++t)
- {
- uword* thread_hist_mem = thread_hist(t).memptr();
-
- const uword start_index = boundaries.at(0,t);
- const uword end_index = boundaries.at(1,t);
-
- for(uword i=start_index; i <= end_index; ++i)
- {
- const eT* X_colptr = X.colptr(i);
-
- eT best_dist = Datum<eT>::inf;
- uword best_g = 0;
-
- for(uword g=0; g < N_gaus; ++g)
- {
- const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.colptr(g), X_colptr);
-
- if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; }
- }
-
- thread_hist_mem[best_g]++;
- }
- }
- }
- else
- if(dist_mode == prob_dist)
- {
- const eT* log_hefts_mem = log_hefts.memptr();
-
- #pragma omp parallel for schedule(static)
- for(uword t=0; t < n_threads; ++t)
- {
- uword* thread_hist_mem = thread_hist(t).memptr();
-
- const uword start_index = boundaries.at(0,t);
- const uword end_index = boundaries.at(1,t);
-
- for(uword i=start_index; i <= end_index; ++i)
- {
- const eT* X_colptr = X.colptr(i);
-
- eT best_p = -Datum<eT>::inf;
- uword best_g = 0;
-
- for(uword g=0; g < N_gaus; ++g)
- {
- const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g];
-
- if(tmp_p >= best_p) { best_p = tmp_p; best_g = g; }
- }
-
- thread_hist_mem[best_g]++;
- }
- }
- }
-
- // reduction
- for(uword t=0; t < n_threads; ++t)
- {
- hist += thread_hist(t);
- }
- }
- #else
- {
- uword* hist_mem = hist.memptr();
-
- if(dist_mode == eucl_dist)
- {
- for(uword i=0; i<X_n_cols; ++i)
- {
- const eT* X_colptr = X.colptr(i);
-
- eT best_dist = Datum<eT>::inf;
- uword best_g = 0;
-
- for(uword g=0; g < N_gaus; ++g)
- {
- const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.colptr(g), X_colptr);
-
- if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; }
- }
-
- hist_mem[best_g]++;
- }
- }
- else
- if(dist_mode == prob_dist)
- {
- const eT* log_hefts_mem = log_hefts.memptr();
-
- for(uword i=0; i<X_n_cols; ++i)
- {
- const eT* X_colptr = X.colptr(i);
-
- eT best_p = -Datum<eT>::inf;
- uword best_g = 0;
-
- for(uword g=0; g < N_gaus; ++g)
- {
- const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g];
-
- if(tmp_p >= best_p) { best_p = tmp_p; best_g = g; }
- }
-
- hist_mem[best_g]++;
- }
- }
- }
- #endif
- }
- template<typename eT>
- template<uword dist_id>
- inline
- void
- gmm_full<eT>::generate_initial_means(const Mat<eT>& X, const gmm_seed_mode& seed_mode)
- {
- arma_extra_debug_sigprint();
-
- const uword N_dims = means.n_rows;
- const uword N_gaus = means.n_cols;
-
- if( (seed_mode == static_subset) || (seed_mode == random_subset) )
- {
- uvec initial_indices;
-
- if(seed_mode == static_subset) { initial_indices = linspace<uvec>(0, X.n_cols-1, N_gaus); }
- else if(seed_mode == random_subset) { initial_indices = randperm<uvec>(X.n_cols, N_gaus); }
-
- // initial_indices.print("initial_indices:");
-
- access::rw(means) = X.cols(initial_indices);
- }
- else
- if( (seed_mode == static_spread) || (seed_mode == random_spread) )
- {
- // going through all of the samples can be extremely time consuming;
- // instead, if there are enough samples, randomly choose samples with probability 0.1
-
- const bool use_sampling = ((X.n_cols/uword(100)) > N_gaus);
- const uword step = (use_sampling) ? uword(10) : uword(1);
-
- uword start_index = 0;
-
- if(seed_mode == static_spread) { start_index = X.n_cols / 2; }
- else if(seed_mode == random_spread) { start_index = as_scalar(randi<uvec>(1, distr_param(0,X.n_cols-1))); }
-
- access::rw(means).col(0) = X.unsafe_col(start_index);
-
- const eT* mah_aux_mem = mah_aux.memptr();
-
- running_stat<double> rs;
-
- for(uword g=1; g < N_gaus; ++g)
- {
- eT max_dist = eT(0);
- uword best_i = uword(0);
- uword start_i = uword(0);
-
- if(use_sampling)
- {
- uword start_i_proposed = uword(0);
-
- if(seed_mode == static_spread) { start_i_proposed = g % uword(10); }
- if(seed_mode == random_spread) { start_i_proposed = as_scalar(randi<uvec>(1, distr_param(0,9))); }
-
- if(start_i_proposed < X.n_cols) { start_i = start_i_proposed; }
- }
-
-
- for(uword i=start_i; i < X.n_cols; i += step)
- {
- rs.reset();
-
- const eT* X_colptr = X.colptr(i);
-
- bool ignore_i = false;
-
- // find the average distance between sample i and the means so far
- for(uword h = 0; h < g; ++h)
- {
- const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, means.colptr(h), mah_aux_mem);
-
- // ignore sample already selected as a mean
- if(dist == eT(0)) { ignore_i = true; break; }
- else { rs(dist); }
- }
-
- if( (rs.mean() >= max_dist) && (ignore_i == false))
- {
- max_dist = eT(rs.mean()); best_i = i;
- }
- }
-
- // set the mean to the sample that is the furthest away from the means so far
- access::rw(means).col(g) = X.unsafe_col(best_i);
- }
- }
-
- // get_cout_stream() << "generate_initial_means():" << '\n';
- // means.print();
- }
- template<typename eT>
- template<uword dist_id>
- inline
- void
- gmm_full<eT>::generate_initial_params(const Mat<eT>& X, const eT var_floor)
- {
- arma_extra_debug_sigprint();
-
- const uword N_dims = means.n_rows;
- const uword N_gaus = means.n_cols;
-
- const eT* mah_aux_mem = mah_aux.memptr();
-
- const uword X_n_cols = X.n_cols;
-
- if(X_n_cols == 0) { return; }
-
- // as the covariances are calculated via accumulators,
- // the means also need to be calculated via accumulators to ensure numerical consistency
-
- Mat<eT> acc_means(N_dims, N_gaus, fill::zeros);
- Mat<eT> acc_dcovs(N_dims, N_gaus, fill::zeros);
-
- Row<uword> acc_hefts(N_gaus, fill::zeros);
-
- uword* acc_hefts_mem = acc_hefts.memptr();
-
- #if defined(ARMA_USE_OPENMP)
- {
- const umat boundaries = internal_gen_boundaries(X_n_cols);
-
- const uword n_threads = boundaries.n_cols;
-
- field< Mat<eT> > t_acc_means(n_threads);
- field< Mat<eT> > t_acc_dcovs(n_threads);
- field< Row<uword> > t_acc_hefts(n_threads);
-
- for(uword t=0; t < n_threads; ++t)
- {
- t_acc_means(t).zeros(N_dims, N_gaus);
- t_acc_dcovs(t).zeros(N_dims, N_gaus);
- t_acc_hefts(t).zeros(N_gaus);
- }
-
- #pragma omp parallel for schedule(static)
- for(uword t=0; t < n_threads; ++t)
- {
- uword* t_acc_hefts_mem = t_acc_hefts(t).memptr();
-
- const uword start_index = boundaries.at(0,t);
- const uword end_index = boundaries.at(1,t);
-
- for(uword i=start_index; i <= end_index; ++i)
- {
- const eT* X_colptr = X.colptr(i);
-
- eT min_dist = Datum<eT>::inf;
- uword best_g = 0;
-
- for(uword g=0; g<N_gaus; ++g)
- {
- const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, means.colptr(g), mah_aux_mem);
-
- if(dist < min_dist) { min_dist = dist; best_g = g; }
- }
-
- eT* t_acc_mean = t_acc_means(t).colptr(best_g);
- eT* t_acc_dcov = t_acc_dcovs(t).colptr(best_g);
-
- for(uword d=0; d<N_dims; ++d)
- {
- const eT x_d = X_colptr[d];
-
- t_acc_mean[d] += x_d;
- t_acc_dcov[d] += x_d*x_d;
- }
-
- t_acc_hefts_mem[best_g]++;
- }
- }
-
- // reduction
- acc_means = t_acc_means(0);
- acc_dcovs = t_acc_dcovs(0);
- acc_hefts = t_acc_hefts(0);
-
- for(uword t=1; t < n_threads; ++t)
- {
- acc_means += t_acc_means(t);
- acc_dcovs += t_acc_dcovs(t);
- acc_hefts += t_acc_hefts(t);
- }
- }
- #else
- {
- for(uword i=0; i<X_n_cols; ++i)
- {
- const eT* X_colptr = X.colptr(i);
-
- eT min_dist = Datum<eT>::inf;
- uword best_g = 0;
-
- for(uword g=0; g<N_gaus; ++g)
- {
- const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, means.colptr(g), mah_aux_mem);
-
- if(dist < min_dist) { min_dist = dist; best_g = g; }
- }
-
- eT* acc_mean = acc_means.colptr(best_g);
- eT* acc_dcov = acc_dcovs.colptr(best_g);
-
- for(uword d=0; d<N_dims; ++d)
- {
- const eT x_d = X_colptr[d];
-
- acc_mean[d] += x_d;
- acc_dcov[d] += x_d*x_d;
- }
-
- acc_hefts_mem[best_g]++;
- }
- }
- #endif
-
- eT* hefts_mem = access::rw(hefts).memptr();
-
- for(uword g=0; g<N_gaus; ++g)
- {
- const eT* acc_mean = acc_means.colptr(g);
- const eT* acc_dcov = acc_dcovs.colptr(g);
- const uword acc_heft = acc_hefts_mem[g];
-
- eT* mean = access::rw(means).colptr(g);
-
- Mat<eT>& fcov = access::rw(fcovs).slice(g);
- fcov.zeros();
-
- for(uword d=0; d<N_dims; ++d)
- {
- const eT tmp = acc_mean[d] / eT(acc_heft);
-
- mean[d] = (acc_heft >= 1) ? tmp : eT(0);
- fcov.at(d,d) = (acc_heft >= 2) ? eT((acc_dcov[d] / eT(acc_heft)) - (tmp*tmp)) : eT(var_floor);
- }
-
- hefts_mem[g] = eT(acc_heft) / eT(X_n_cols);
- }
-
- em_fix_params(var_floor);
- }
- //! multi-threaded implementation of k-means, inspired by MapReduce
- template<typename eT>
- template<uword dist_id>
- inline
- bool
- gmm_full<eT>::km_iterate(const Mat<eT>& X, const uword max_iter, const bool verbose)
- {
- arma_extra_debug_sigprint();
-
- if(verbose)
- {
- get_cout_stream().unsetf(ios::showbase);
- get_cout_stream().unsetf(ios::uppercase);
- get_cout_stream().unsetf(ios::showpos);
- get_cout_stream().unsetf(ios::scientific);
-
- get_cout_stream().setf(ios::right);
- get_cout_stream().setf(ios::fixed);
- }
-
- const uword X_n_cols = X.n_cols;
-
- if(X_n_cols == 0) { return true; }
-
- const uword N_dims = means.n_rows;
- const uword N_gaus = means.n_cols;
-
- const eT* mah_aux_mem = mah_aux.memptr();
-
- Mat<eT> acc_means(N_dims, N_gaus, fill::zeros);
- Row<uword> acc_hefts(N_gaus, fill::zeros);
- Row<uword> last_indx(N_gaus, fill::zeros);
-
- Mat<eT> new_means = means;
- Mat<eT> old_means = means;
-
- running_mean_scalar<eT> rs_delta;
-
- #if defined(ARMA_USE_OPENMP)
- const umat boundaries = internal_gen_boundaries(X_n_cols);
- const uword n_threads = boundaries.n_cols;
-
- field< Mat<eT> > t_acc_means(n_threads);
- field< Row<uword> > t_acc_hefts(n_threads);
- field< Row<uword> > t_last_indx(n_threads);
- #else
- const uword n_threads = 1;
- #endif
-
- if(verbose) { get_cout_stream() << "gmm_full::learn(): k-means: n_threads: " << n_threads << '\n'; get_cout_stream().flush(); }
-
- for(uword iter=1; iter <= max_iter; ++iter)
- {
- #if defined(ARMA_USE_OPENMP)
- {
- for(uword t=0; t < n_threads; ++t)
- {
- t_acc_means(t).zeros(N_dims, N_gaus);
- t_acc_hefts(t).zeros(N_gaus);
- t_last_indx(t).zeros(N_gaus);
- }
-
- #pragma omp parallel for schedule(static)
- for(uword t=0; t < n_threads; ++t)
- {
- Mat<eT>& t_acc_means_t = t_acc_means(t);
- uword* t_acc_hefts_mem = t_acc_hefts(t).memptr();
- uword* t_last_indx_mem = t_last_indx(t).memptr();
-
- const uword start_index = boundaries.at(0,t);
- const uword end_index = boundaries.at(1,t);
-
- for(uword i=start_index; i <= end_index; ++i)
- {
- const eT* X_colptr = X.colptr(i);
-
- eT min_dist = Datum<eT>::inf;
- uword best_g = 0;
-
- for(uword g=0; g<N_gaus; ++g)
- {
- const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, old_means.colptr(g), mah_aux_mem);
-
- if(dist < min_dist) { min_dist = dist; best_g = g; }
- }
-
- eT* t_acc_mean = t_acc_means_t.colptr(best_g);
-
- for(uword d=0; d<N_dims; ++d) { t_acc_mean[d] += X_colptr[d]; }
-
- t_acc_hefts_mem[best_g]++;
- t_last_indx_mem[best_g] = i;
- }
- }
-
- // reduction
-
- acc_means = t_acc_means(0);
- acc_hefts = t_acc_hefts(0);
-
- for(uword t=1; t < n_threads; ++t)
- {
- acc_means += t_acc_means(t);
- acc_hefts += t_acc_hefts(t);
- }
-
- for(uword g=0; g < N_gaus; ++g)
- for(uword t=0; t < n_threads; ++t)
- {
- if( t_acc_hefts(t)(g) >= 1 ) { last_indx(g) = t_last_indx(t)(g); }
- }
- }
- #else
- {
- uword* acc_hefts_mem = acc_hefts.memptr();
- uword* last_indx_mem = last_indx.memptr();
-
- for(uword i=0; i < X_n_cols; ++i)
- {
- const eT* X_colptr = X.colptr(i);
-
- eT min_dist = Datum<eT>::inf;
- uword best_g = 0;
-
- for(uword g=0; g<N_gaus; ++g)
- {
- const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, old_means.colptr(g), mah_aux_mem);
-
- if(dist < min_dist) { min_dist = dist; best_g = g; }
- }
-
- eT* acc_mean = acc_means.colptr(best_g);
-
- for(uword d=0; d<N_dims; ++d) { acc_mean[d] += X_colptr[d]; }
-
- acc_hefts_mem[best_g]++;
- last_indx_mem[best_g] = i;
- }
- }
- #endif
-
- // generate new means
-
- uword* acc_hefts_mem = acc_hefts.memptr();
-
- for(uword g=0; g < N_gaus; ++g)
- {
- const eT* acc_mean = acc_means.colptr(g);
- const uword acc_heft = acc_hefts_mem[g];
-
- eT* new_mean = access::rw(new_means).colptr(g);
-
- for(uword d=0; d<N_dims; ++d)
- {
- new_mean[d] = (acc_heft >= 1) ? (acc_mean[d] / eT(acc_heft)) : eT(0);
- }
- }
-
-
- // heuristics to resurrect dead means
-
- const uvec dead_gs = find(acc_hefts == uword(0));
-
- if(dead_gs.n_elem > 0)
- {
- if(verbose) { get_cout_stream() << "gmm_full::learn(): k-means: recovering from dead means\n"; get_cout_stream().flush(); }
-
- uword* last_indx_mem = last_indx.memptr();
-
- const uvec live_gs = sort( find(acc_hefts >= uword(2)), "descend" );
-
- if(live_gs.n_elem == 0) { return false; }
-
- uword live_gs_count = 0;
-
- for(uword dead_gs_count = 0; dead_gs_count < dead_gs.n_elem; ++dead_gs_count)
- {
- const uword dead_g_id = dead_gs(dead_gs_count);
-
- uword proposed_i = 0;
-
- if(live_gs_count < live_gs.n_elem)
- {
- const uword live_g_id = live_gs(live_gs_count); ++live_gs_count;
-
- if(live_g_id == dead_g_id) { return false; }
-
- // recover by using a sample from a known good mean
- proposed_i = last_indx_mem[live_g_id];
- }
- else
- {
- // recover by using a randomly seleced sample (last resort)
- proposed_i = as_scalar(randi<uvec>(1, distr_param(0,X_n_cols-1)));
- }
-
- if(proposed_i >= X_n_cols) { return false; }
-
- new_means.col(dead_g_id) = X.col(proposed_i);
- }
- }
- rs_delta.reset();
-
- for(uword g=0; g < N_gaus; ++g)
- {
- rs_delta( distance<eT,dist_id>::eval(N_dims, old_means.colptr(g), new_means.colptr(g), mah_aux_mem) );
- }
-
- if(verbose)
- {
- get_cout_stream() << "gmm_full::learn(): k-means: iteration: ";
- get_cout_stream().unsetf(ios::scientific);
- get_cout_stream().setf(ios::fixed);
- get_cout_stream().width(std::streamsize(4));
- get_cout_stream() << iter;
- get_cout_stream() << " delta: ";
- get_cout_stream().unsetf(ios::fixed);
- //get_cout_stream().setf(ios::scientific);
- get_cout_stream() << rs_delta.mean() << '\n';
- get_cout_stream().flush();
- }
-
- arma::swap(old_means, new_means);
-
- if(rs_delta.mean() <= Datum<eT>::eps) { break; }
- }
-
- access::rw(means) = old_means;
-
- if(means.is_finite() == false) { return false; }
-
- return true;
- }
- //! multi-threaded implementation of Expectation-Maximisation, inspired by MapReduce
- template<typename eT>
- inline
- bool
- gmm_full<eT>::em_iterate(const Mat<eT>& X, const uword max_iter, const eT var_floor, const bool verbose)
- {
- arma_extra_debug_sigprint();
-
- const uword N_dims = means.n_rows;
- const uword N_gaus = means.n_cols;
-
- if(verbose)
- {
- get_cout_stream().unsetf(ios::showbase);
- get_cout_stream().unsetf(ios::uppercase);
- get_cout_stream().unsetf(ios::showpos);
- get_cout_stream().unsetf(ios::scientific);
-
- get_cout_stream().setf(ios::right);
- get_cout_stream().setf(ios::fixed);
- }
-
- const umat boundaries = internal_gen_boundaries(X.n_cols);
-
- const uword n_threads = boundaries.n_cols;
-
- field< Mat<eT> > t_acc_means(n_threads);
- field< Cube<eT> > t_acc_fcovs(n_threads);
-
- field< Col<eT> > t_acc_norm_lhoods(n_threads);
- field< Col<eT> > t_gaus_log_lhoods(n_threads);
-
- Col<eT> t_progress_log_lhood(n_threads);
-
- for(uword t=0; t<n_threads; t++)
- {
- t_acc_means[t].set_size(N_dims, N_gaus);
- t_acc_fcovs[t].set_size(N_dims, N_dims, N_gaus);
-
- t_acc_norm_lhoods[t].set_size(N_gaus);
- t_gaus_log_lhoods[t].set_size(N_gaus);
- }
-
-
- if(verbose)
- {
- get_cout_stream() << "gmm_full::learn(): EM: n_threads: " << n_threads << '\n';
- }
-
- eT old_avg_log_p = -Datum<eT>::inf;
-
- const bool calc_chol = false;
-
- for(uword iter=1; iter <= max_iter; ++iter)
- {
- init_constants(calc_chol);
-
- em_update_params(X, boundaries, t_acc_means, t_acc_fcovs, t_acc_norm_lhoods, t_gaus_log_lhoods, t_progress_log_lhood, var_floor);
-
- em_fix_params(var_floor);
-
- const eT new_avg_log_p = accu(t_progress_log_lhood) / eT(t_progress_log_lhood.n_elem);
-
- if(verbose)
- {
- get_cout_stream() << "gmm_full::learn(): EM: iteration: ";
- get_cout_stream().unsetf(ios::scientific);
- get_cout_stream().setf(ios::fixed);
- get_cout_stream().width(std::streamsize(4));
- get_cout_stream() << iter;
- get_cout_stream() << " avg_log_p: ";
- get_cout_stream().unsetf(ios::fixed);
- //get_cout_stream().setf(ios::scientific);
- get_cout_stream() << new_avg_log_p << '\n';
- get_cout_stream().flush();
- }
-
- if(arma_isfinite(new_avg_log_p) == false) { return false; }
-
- if(std::abs(old_avg_log_p - new_avg_log_p) <= Datum<eT>::eps) { break; }
-
-
- old_avg_log_p = new_avg_log_p;
- }
-
-
- for(uword g=0; g < N_gaus; ++g)
- {
- const Mat<eT>& fcov = fcovs.slice(g);
-
- if(any(vectorise(fcov.diag()) <= eT(0))) { return false; }
- }
-
- if(means.is_finite() == false) { return false; }
- if(fcovs.is_finite() == false) { return false; }
- if(hefts.is_finite() == false) { return false; }
-
- return true;
- }
- template<typename eT>
- inline
- void
- gmm_full<eT>::em_update_params
- (
- const Mat<eT>& X,
- const umat& boundaries,
- field< Mat<eT> >& t_acc_means,
- field< Cube<eT> >& t_acc_fcovs,
- field< Col<eT> >& t_acc_norm_lhoods,
- field< Col<eT> >& t_gaus_log_lhoods,
- Col<eT>& t_progress_log_lhood,
- const eT var_floor
- )
- {
- arma_extra_debug_sigprint();
-
- const uword n_threads = boundaries.n_cols;
-
-
- // em_generate_acc() is the "map" operation, which produces partial accumulators for means, diagonal covariances and hefts
-
- #if defined(ARMA_USE_OPENMP)
- {
- #pragma omp parallel for schedule(static)
- for(uword t=0; t<n_threads; t++)
- {
- Mat<eT>& acc_means = t_acc_means[t];
- Cube<eT>& acc_fcovs = t_acc_fcovs[t];
- Col<eT>& acc_norm_lhoods = t_acc_norm_lhoods[t];
- Col<eT>& gaus_log_lhoods = t_gaus_log_lhoods[t];
- eT& progress_log_lhood = t_progress_log_lhood[t];
-
- em_generate_acc(X, boundaries.at(0,t), boundaries.at(1,t), acc_means, acc_fcovs, acc_norm_lhoods, gaus_log_lhoods, progress_log_lhood);
- }
- }
- #else
- {
- em_generate_acc(X, boundaries.at(0,0), boundaries.at(1,0), t_acc_means[0], t_acc_fcovs[0], t_acc_norm_lhoods[0], t_gaus_log_lhoods[0], t_progress_log_lhood[0]);
- }
- #endif
-
- const uword N_dims = means.n_rows;
- const uword N_gaus = means.n_cols;
-
- Mat<eT>& final_acc_means = t_acc_means[0];
- Cube<eT>& final_acc_fcovs = t_acc_fcovs[0];
-
- Col<eT>& final_acc_norm_lhoods = t_acc_norm_lhoods[0];
-
-
- // the "reduce" operation, which combines the partial accumulators produced by the separate threads
-
- for(uword t=1; t<n_threads; t++)
- {
- final_acc_means += t_acc_means[t];
- final_acc_fcovs += t_acc_fcovs[t];
-
- final_acc_norm_lhoods += t_acc_norm_lhoods[t];
- }
-
-
- eT* hefts_mem = access::rw(hefts).memptr();
-
- Mat<eT> mean_outer(N_dims, N_dims);
-
-
- //// update each component without sanity checking
- //for(uword g=0; g < N_gaus; ++g)
- // {
- // const eT acc_norm_lhood = (std::max)( final_acc_norm_lhoods[g], std::numeric_limits<eT>::min() );
- //
- // hefts_mem[g] = acc_norm_lhood / eT(X.n_cols);
- //
- // eT* mean_mem = access::rw(means).colptr(g);
- // eT* acc_mean_mem = final_acc_means.colptr(g);
- //
- // for(uword d=0; d < N_dims; ++d)
- // {
- // mean_mem[d] = acc_mean_mem[d] / acc_norm_lhood;
- // }
- //
- // const Col<eT> mean(mean_mem, N_dims, false, true);
- //
- // mean_outer = mean * mean.t();
- //
- // Mat<eT>& fcov = access::rw(fcovs).slice(g);
- // Mat<eT>& acc_fcov = final_acc_fcovs.slice(g);
- //
- // fcov = acc_fcov / acc_norm_lhood - mean_outer;
- // }
-
-
- // conditionally update each component; if only a subset of the hefts was updated, em_fix_params() will sanitise them
- for(uword g=0; g < N_gaus; ++g)
- {
- const eT acc_norm_lhood = (std::max)( final_acc_norm_lhoods[g], std::numeric_limits<eT>::min() );
-
- if(arma_isfinite(acc_norm_lhood) == false) { continue; }
-
- eT* acc_mean_mem = final_acc_means.colptr(g);
-
- for(uword d=0; d < N_dims; ++d)
- {
- acc_mean_mem[d] /= acc_norm_lhood;
- }
-
- const Col<eT> new_mean(acc_mean_mem, N_dims, false, true);
-
- mean_outer = new_mean * new_mean.t();
-
- Mat<eT>& acc_fcov = final_acc_fcovs.slice(g);
-
- acc_fcov /= acc_norm_lhood;
- acc_fcov -= mean_outer;
-
- for(uword d=0; d < N_dims; ++d)
- {
- eT& val = acc_fcov.at(d,d);
-
- if(val < var_floor) { val = var_floor; }
- }
-
- if(acc_fcov.is_finite() == false) { continue; }
-
- eT log_det_val = eT(0);
- eT log_det_sign = eT(0);
-
- log_det(log_det_val, log_det_sign, acc_fcov);
-
- const bool log_det_ok = ( (arma_isfinite(log_det_val)) && (log_det_sign > eT(0)) );
-
- const bool inv_ok = (log_det_ok) ? bool(auxlib::inv_sympd(mean_outer, acc_fcov)) : bool(false); // mean_outer is used as a junk matrix
-
- if(log_det_ok && inv_ok)
- {
- hefts_mem[g] = acc_norm_lhood / eT(X.n_cols);
-
- eT* mean_mem = access::rw(means).colptr(g);
-
- for(uword d=0; d < N_dims; ++d)
- {
- mean_mem[d] = acc_mean_mem[d];
- }
-
- Mat<eT>& fcov = access::rw(fcovs).slice(g);
-
- fcov = acc_fcov;
- }
- }
- }
- template<typename eT>
- inline
- void
- gmm_full<eT>::em_generate_acc
- (
- const Mat<eT>& X,
- const uword start_index,
- const uword end_index,
- Mat<eT>& acc_means,
- Cube<eT>& acc_fcovs,
- Col<eT>& acc_norm_lhoods,
- Col<eT>& gaus_log_lhoods,
- eT& progress_log_lhood
- )
- const
- {
- arma_extra_debug_sigprint();
-
- progress_log_lhood = eT(0);
-
- acc_means.zeros();
- acc_fcovs.zeros();
-
- acc_norm_lhoods.zeros();
- gaus_log_lhoods.zeros();
-
- const uword N_dims = means.n_rows;
- const uword N_gaus = means.n_cols;
-
- const eT* log_hefts_mem = log_hefts.memptr();
- eT* gaus_log_lhoods_mem = gaus_log_lhoods.memptr();
-
-
- for(uword i=start_index; i <= end_index; i++)
- {
- const eT* x = X.colptr(i);
-
- for(uword g=0; g < N_gaus; ++g)
- {
- gaus_log_lhoods_mem[g] = internal_scalar_log_p(x, g) + log_hefts_mem[g];
- }
-
- eT log_lhood_sum = gaus_log_lhoods_mem[0];
-
- for(uword g=1; g < N_gaus; ++g)
- {
- log_lhood_sum = log_add_exp(log_lhood_sum, gaus_log_lhoods_mem[g]);
- }
-
- progress_log_lhood += log_lhood_sum;
-
- for(uword g=0; g < N_gaus; ++g)
- {
- const eT norm_lhood = std::exp(gaus_log_lhoods_mem[g] - log_lhood_sum);
-
- acc_norm_lhoods[g] += norm_lhood;
-
- eT* acc_mean_mem = acc_means.colptr(g);
-
- for(uword d=0; d < N_dims; ++d)
- {
- acc_mean_mem[d] += x[d] * norm_lhood;
- }
-
- Mat<eT>& acc_fcov = access::rw(acc_fcovs).slice(g);
-
- // specialised version of acc_fcov += norm_lhood * (xx * xx.t());
-
- for(uword d=0; d < N_dims; ++d)
- {
- const uword dp1 = d+1;
-
- const eT xd = x[d];
-
- eT* acc_fcov_col_d = acc_fcov.colptr(d) + d;
- eT* acc_fcov_row_d = &(acc_fcov.at(d,dp1));
-
- (*acc_fcov_col_d) += norm_lhood * (xd * xd); acc_fcov_col_d++;
-
- for(uword e=dp1; e < N_dims; ++e)
- {
- const eT val = norm_lhood * (xd * x[e]);
-
- (*acc_fcov_col_d) += val; acc_fcov_col_d++;
- (*acc_fcov_row_d) += val; acc_fcov_row_d += N_dims;
- }
- }
- }
- }
-
- progress_log_lhood /= eT((end_index - start_index) + 1);
- }
- template<typename eT>
- inline
- void
- gmm_full<eT>::em_fix_params(const eT var_floor)
- {
- arma_extra_debug_sigprint();
-
- const uword N_dims = means.n_rows;
- const uword N_gaus = means.n_cols;
-
- const eT var_ceiling = std::numeric_limits<eT>::max();
-
- for(uword g=0; g < N_gaus; ++g)
- {
- Mat<eT>& fcov = access::rw(fcovs).slice(g);
-
- for(uword d=0; d < N_dims; ++d)
- {
- eT& var_val = fcov.at(d,d);
-
- if(var_val < var_floor ) { var_val = var_floor; }
- else if(var_val > var_ceiling) { var_val = var_ceiling; }
- else if(arma_isnan(var_val) ) { var_val = eT(1); }
- }
- }
-
-
- eT* hefts_mem = access::rw(hefts).memptr();
-
- for(uword g1=0; g1 < N_gaus; ++g1)
- {
- if(hefts_mem[g1] > eT(0))
- {
- const eT* means_colptr_g1 = means.colptr(g1);
-
- for(uword g2=(g1+1); g2 < N_gaus; ++g2)
- {
- if( (hefts_mem[g2] > eT(0)) && (std::abs(hefts_mem[g1] - hefts_mem[g2]) <= std::numeric_limits<eT>::epsilon()) )
- {
- const eT dist = distance<eT,1>::eval(N_dims, means_colptr_g1, means.colptr(g2), means_colptr_g1);
-
- if(dist == eT(0)) { hefts_mem[g2] = eT(0); }
- }
- }
- }
- }
-
- const eT heft_floor = std::numeric_limits<eT>::min();
- const eT heft_initial = eT(1) / eT(N_gaus);
-
- for(uword i=0; i < N_gaus; ++i)
- {
- eT& heft_val = hefts_mem[i];
-
- if(heft_val < heft_floor) { heft_val = heft_floor; }
- else if(heft_val > eT(1) ) { heft_val = eT(1); }
- else if(arma_isnan(heft_val) ) { heft_val = heft_initial; }
- }
-
- const eT heft_sum = accu(hefts);
-
- if((heft_sum < (eT(1) - Datum<eT>::eps)) || (heft_sum > (eT(1) + Datum<eT>::eps))) { access::rw(hefts) /= heft_sum; }
- }
- } // namespace gmm_priv
- //! @}
|