gmm_full_meat.hpp 66 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762277227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376237723782379238023812382238323842385238623872388238923902391239223932394239523962397239823992400240124022403240424052406240724082409241024112412241324142415241624172418241924202421242224232424242524262427242824292430243124322433243424352436243724382439244024412442244324442445244624472448244924502451245224532454245524562457245824592460246124622463246424652466246724682469247024712472247324742475247624772478247924802481248224832484248524862487248824892490249124922493249424952496249724982499250025012502250325042505250625072508250925102511251225132514251525162517251825192520252125222523252425252526252725282529253025312532253325342535253625372538253925402541254225432544254525462547254825492550255125522553255425552556255725582559256025612562256325642565256625672568256925702571257225732574257525762577257825792580258125822583258425852586258725882589259025912592259325942595259625972598259926002601260226032604260526062607260826092610261126122613261426152616261726182619262026212622262326242625262626272628262926302631263226332634263526362637263826392640264126422643264426452646264726482649265026512652265326542655265626572658265926602661266226632664266526662667266826692670267126722673267426752676267726782679268026812682268326842685268626872688268926902691269226932694269526962697269826992700270127022703270427052706270727082709271027112712271327142715271627172718271927202721272227232724272527262727272827292730273127322733
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup gmm_full
  16. //! @{
  17. namespace gmm_priv
  18. {
  19. template<typename eT>
  20. inline
  21. gmm_full<eT>::~gmm_full()
  22. {
  23. arma_extra_debug_sigprint_this(this);
  24. arma_type_check(( (is_same_type<eT,float>::value == false) && (is_same_type<eT,double>::value == false) ));
  25. }
  26. template<typename eT>
  27. inline
  28. gmm_full<eT>::gmm_full()
  29. {
  30. arma_extra_debug_sigprint_this(this);
  31. }
  32. template<typename eT>
  33. inline
  34. gmm_full<eT>::gmm_full(const gmm_full<eT>& x)
  35. {
  36. arma_extra_debug_sigprint_this(this);
  37. init(x);
  38. }
  39. template<typename eT>
  40. inline
  41. gmm_full<eT>&
  42. gmm_full<eT>::operator=(const gmm_full<eT>& x)
  43. {
  44. arma_extra_debug_sigprint();
  45. init(x);
  46. return *this;
  47. }
  48. template<typename eT>
  49. inline
  50. gmm_full<eT>::gmm_full(const gmm_diag<eT>& x)
  51. {
  52. arma_extra_debug_sigprint_this(this);
  53. init(x);
  54. }
  55. template<typename eT>
  56. inline
  57. gmm_full<eT>&
  58. gmm_full<eT>::operator=(const gmm_diag<eT>& x)
  59. {
  60. arma_extra_debug_sigprint();
  61. init(x);
  62. return *this;
  63. }
  64. template<typename eT>
  65. inline
  66. gmm_full<eT>::gmm_full(const uword in_n_dims, const uword in_n_gaus)
  67. {
  68. arma_extra_debug_sigprint_this(this);
  69. init(in_n_dims, in_n_gaus);
  70. }
  71. template<typename eT>
  72. inline
  73. void
  74. gmm_full<eT>::reset()
  75. {
  76. arma_extra_debug_sigprint();
  77. init(0, 0);
  78. }
  79. template<typename eT>
  80. inline
  81. void
  82. gmm_full<eT>::reset(const uword in_n_dims, const uword in_n_gaus)
  83. {
  84. arma_extra_debug_sigprint();
  85. init(in_n_dims, in_n_gaus);
  86. }
  87. template<typename eT>
  88. template<typename T1, typename T2, typename T3>
  89. inline
  90. void
  91. gmm_full<eT>::set_params(const Base<eT,T1>& in_means_expr, const BaseCube<eT,T2>& in_fcovs_expr, const Base<eT,T3>& in_hefts_expr)
  92. {
  93. arma_extra_debug_sigprint();
  94. const unwrap <T1> tmp1(in_means_expr.get_ref());
  95. const unwrap_cube<T2> tmp2(in_fcovs_expr.get_ref());
  96. const unwrap <T3> tmp3(in_hefts_expr.get_ref());
  97. const Mat <eT>& in_means = tmp1.M;
  98. const Cube<eT>& in_fcovs = tmp2.M;
  99. const Mat <eT>& in_hefts = tmp3.M;
  100. arma_debug_check
  101. (
  102. (in_means.n_cols != in_fcovs.n_slices) || (in_means.n_rows != in_fcovs.n_rows) || (in_fcovs.n_rows != in_fcovs.n_cols) || (in_hefts.n_cols != in_means.n_cols) || (in_hefts.n_rows != 1),
  103. "gmm_full::set_params(): given parameters have inconsistent and/or wrong sizes"
  104. );
  105. arma_debug_check( (in_means.is_finite() == false), "gmm_full::set_params(): given means have non-finite values" );
  106. arma_debug_check( (in_fcovs.is_finite() == false), "gmm_full::set_params(): given fcovs have non-finite values" );
  107. arma_debug_check( (in_hefts.is_finite() == false), "gmm_full::set_params(): given hefts have non-finite values" );
  108. for(uword g=0; g < in_fcovs.n_slices; ++g)
  109. {
  110. arma_debug_check( (any(diagvec(in_fcovs.slice(g)) <= eT(0))), "gmm_full::set_params(): given fcovs have negative or zero values on diagonals" );
  111. }
  112. arma_debug_check( (any(vectorise(in_hefts) < eT(0))), "gmm_full::set_params(): given hefts have negative values" );
  113. const eT s = accu(in_hefts);
  114. arma_debug_check( ((s < (eT(1) - eT(0.001))) || (s > (eT(1) + eT(0.001)))), "gmm_full::set_params(): sum of given hefts is not 1" );
  115. access::rw(means) = in_means;
  116. access::rw(fcovs) = in_fcovs;
  117. access::rw(hefts) = in_hefts;
  118. init_constants();
  119. }
  120. template<typename eT>
  121. template<typename T1>
  122. inline
  123. void
  124. gmm_full<eT>::set_means(const Base<eT,T1>& in_means_expr)
  125. {
  126. arma_extra_debug_sigprint();
  127. const unwrap<T1> tmp(in_means_expr.get_ref());
  128. const Mat<eT>& in_means = tmp.M;
  129. arma_debug_check( (arma::size(in_means) != arma::size(means)), "gmm_full::set_means(): given means have incompatible size" );
  130. arma_debug_check( (in_means.is_finite() == false), "gmm_full::set_means(): given means have non-finite values" );
  131. access::rw(means) = in_means;
  132. }
  133. template<typename eT>
  134. template<typename T1>
  135. inline
  136. void
  137. gmm_full<eT>::set_fcovs(const BaseCube<eT,T1>& in_fcovs_expr)
  138. {
  139. arma_extra_debug_sigprint();
  140. const unwrap_cube<T1> tmp(in_fcovs_expr.get_ref());
  141. const Cube<eT>& in_fcovs = tmp.M;
  142. arma_debug_check( (arma::size(in_fcovs) != arma::size(fcovs)), "gmm_full::set_fcovs(): given fcovs have incompatible size" );
  143. arma_debug_check( (in_fcovs.is_finite() == false), "gmm_full::set_fcovs(): given fcovs have non-finite values" );
  144. for(uword i=0; i < in_fcovs.n_slices; ++i)
  145. {
  146. arma_debug_check( (any(diagvec(in_fcovs.slice(i)) <= eT(0))), "gmm_full::set_fcovs(): given fcovs have negative or zero values on diagonals" );
  147. }
  148. access::rw(fcovs) = in_fcovs;
  149. init_constants();
  150. }
  151. template<typename eT>
  152. template<typename T1>
  153. inline
  154. void
  155. gmm_full<eT>::set_hefts(const Base<eT,T1>& in_hefts_expr)
  156. {
  157. arma_extra_debug_sigprint();
  158. const unwrap<T1> tmp(in_hefts_expr.get_ref());
  159. const Mat<eT>& in_hefts = tmp.M;
  160. arma_debug_check( (arma::size(in_hefts) != arma::size(hefts)), "gmm_full::set_hefts(): given hefts have incompatible size" );
  161. arma_debug_check( (in_hefts.is_finite() == false), "gmm_full::set_hefts(): given hefts have non-finite values" );
  162. arma_debug_check( (any(vectorise(in_hefts) < eT(0))), "gmm_full::set_hefts(): given hefts have negative values" );
  163. const eT s = accu(in_hefts);
  164. arma_debug_check( ((s < (eT(1) - eT(0.001))) || (s > (eT(1) + eT(0.001)))), "gmm_full::set_hefts(): sum of given hefts is not 1" );
  165. // make sure all hefts are positive and non-zero
  166. const eT* in_hefts_mem = in_hefts.memptr();
  167. eT* hefts_mem = access::rw(hefts).memptr();
  168. for(uword i=0; i < hefts.n_elem; ++i)
  169. {
  170. hefts_mem[i] = (std::max)( in_hefts_mem[i], std::numeric_limits<eT>::min() );
  171. }
  172. access::rw(hefts) /= accu(hefts);
  173. log_hefts = log(hefts);
  174. }
  175. template<typename eT>
  176. inline
  177. uword
  178. gmm_full<eT>::n_dims() const
  179. {
  180. return means.n_rows;
  181. }
  182. template<typename eT>
  183. inline
  184. uword
  185. gmm_full<eT>::n_gaus() const
  186. {
  187. return means.n_cols;
  188. }
  189. template<typename eT>
  190. inline
  191. bool
  192. gmm_full<eT>::load(const std::string name)
  193. {
  194. arma_extra_debug_sigprint();
  195. field< Mat<eT> > storage;
  196. bool status = storage.load(name, arma_binary);
  197. if( (status == false) || (storage.n_elem < 2) )
  198. {
  199. reset();
  200. arma_debug_warn("gmm_full::load(): problem with loading or incompatible format");
  201. return false;
  202. }
  203. uword count = 0;
  204. const Mat<eT>& storage_means = storage(count); ++count;
  205. const Mat<eT>& storage_hefts = storage(count); ++count;
  206. const uword N_dims = storage_means.n_rows;
  207. const uword N_gaus = storage_means.n_cols;
  208. if( (storage.n_elem != (N_gaus + 2)) || (storage_hefts.n_rows != 1) || (storage_hefts.n_cols != N_gaus) )
  209. {
  210. reset();
  211. arma_debug_warn("gmm_full::load(): incompatible format");
  212. return false;
  213. }
  214. reset(N_dims, N_gaus);
  215. access::rw(means) = storage_means;
  216. access::rw(hefts) = storage_hefts;
  217. for(uword g=0; g < N_gaus; ++g)
  218. {
  219. const Mat<eT>& storage_fcov = storage(count); ++count;
  220. if( (storage_fcov.n_rows != N_dims) || (storage_fcov.n_cols != N_dims) )
  221. {
  222. reset();
  223. arma_debug_warn("gmm_full::load(): incompatible format");
  224. return false;
  225. }
  226. access::rw(fcovs).slice(g) = storage_fcov;
  227. }
  228. init_constants();
  229. return true;
  230. }
  231. template<typename eT>
  232. inline
  233. bool
  234. gmm_full<eT>::save(const std::string name) const
  235. {
  236. arma_extra_debug_sigprint();
  237. const uword N_gaus = means.n_cols;
  238. field< Mat<eT> > storage(2 + N_gaus);
  239. uword count = 0;
  240. storage(count) = means; ++count;
  241. storage(count) = hefts; ++count;
  242. for(uword g=0; g < N_gaus; ++g)
  243. {
  244. storage(count) = fcovs.slice(g); ++count;
  245. }
  246. const bool status = storage.save(name, arma_binary);
  247. return status;
  248. }
  249. template<typename eT>
  250. inline
  251. Col<eT>
  252. gmm_full<eT>::generate() const
  253. {
  254. arma_extra_debug_sigprint();
  255. const uword N_dims = means.n_rows;
  256. const uword N_gaus = means.n_cols;
  257. Col<eT> out( (N_gaus > 0) ? N_dims : uword(0) );
  258. Col<eT> tmp( (N_gaus > 0) ? N_dims : uword(0), fill::randn );
  259. if(N_gaus > 0)
  260. {
  261. const double val = randu<double>();
  262. double csum = double(0);
  263. uword gaus_id = 0;
  264. for(uword j=0; j < N_gaus; ++j)
  265. {
  266. csum += hefts[j];
  267. if(val <= csum) { gaus_id = j; break; }
  268. }
  269. out = chol_fcovs.slice(gaus_id) * tmp;
  270. out += means.col(gaus_id);
  271. }
  272. return out;
  273. }
  274. template<typename eT>
  275. inline
  276. Mat<eT>
  277. gmm_full<eT>::generate(const uword N_vec) const
  278. {
  279. arma_extra_debug_sigprint();
  280. const uword N_dims = means.n_rows;
  281. const uword N_gaus = means.n_cols;
  282. Mat<eT> out( ( (N_gaus > 0) ? N_dims : uword(0) ), N_vec );
  283. Mat<eT> tmp( ( (N_gaus > 0) ? N_dims : uword(0) ), N_vec, fill::randn );
  284. if(N_gaus > 0)
  285. {
  286. const eT* hefts_mem = hefts.memptr();
  287. for(uword i=0; i < N_vec; ++i)
  288. {
  289. const double val = randu<double>();
  290. double csum = double(0);
  291. uword gaus_id = 0;
  292. for(uword j=0; j < N_gaus; ++j)
  293. {
  294. csum += hefts_mem[j];
  295. if(val <= csum) { gaus_id = j; break; }
  296. }
  297. Col<eT> out_vec(out.colptr(i), N_dims, false, true);
  298. Col<eT> tmp_vec(tmp.colptr(i), N_dims, false, true);
  299. out_vec = chol_fcovs.slice(gaus_id) * tmp_vec;
  300. out_vec += means.col(gaus_id);
  301. }
  302. }
  303. return out;
  304. }
  305. template<typename eT>
  306. template<typename T1>
  307. inline
  308. eT
  309. gmm_full<eT>::log_p(const T1& expr, const gmm_empty_arg& junk1, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == true))>::result* junk2) const
  310. {
  311. arma_extra_debug_sigprint();
  312. arma_ignore(junk1);
  313. arma_ignore(junk2);
  314. const uword N_dims = means.n_rows;
  315. const quasi_unwrap<T1> U(expr);
  316. arma_debug_check( (U.M.n_rows != N_dims), "gmm_full::log_p(): incompatible dimensions" );
  317. return internal_scalar_log_p( U.M.memptr() );
  318. }
  319. template<typename eT>
  320. template<typename T1>
  321. inline
  322. eT
  323. gmm_full<eT>::log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == true))>::result* junk2) const
  324. {
  325. arma_extra_debug_sigprint();
  326. arma_ignore(junk2);
  327. const uword N_dims = means.n_rows;
  328. const quasi_unwrap<T1> U(expr);
  329. arma_debug_check( (U.M.n_rows != N_dims), "gmm_full::log_p(): incompatible dimensions" );
  330. arma_debug_check( (gaus_id >= means.n_cols), "gmm_full::log_p(): specified gaussian is out of range" );
  331. return internal_scalar_log_p( U.M.memptr(), gaus_id );
  332. }
  333. template<typename eT>
  334. template<typename T1>
  335. inline
  336. Row<eT>
  337. gmm_full<eT>::log_p(const T1& expr, const gmm_empty_arg& junk1, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == false))>::result* junk2) const
  338. {
  339. arma_extra_debug_sigprint();
  340. arma_ignore(junk1);
  341. arma_ignore(junk2);
  342. const quasi_unwrap<T1> tmp(expr);
  343. const Mat<eT>& X = tmp.M;
  344. return internal_vec_log_p(X);
  345. }
  346. template<typename eT>
  347. template<typename T1>
  348. inline
  349. Row<eT>
  350. gmm_full<eT>::log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == false))>::result* junk2) const
  351. {
  352. arma_extra_debug_sigprint();
  353. arma_ignore(junk2);
  354. const quasi_unwrap<T1> tmp(expr);
  355. const Mat<eT>& X = tmp.M;
  356. return internal_vec_log_p(X, gaus_id);
  357. }
  358. template<typename eT>
  359. template<typename T1>
  360. inline
  361. eT
  362. gmm_full<eT>::sum_log_p(const Base<eT,T1>& expr) const
  363. {
  364. arma_extra_debug_sigprint();
  365. const quasi_unwrap<T1> tmp(expr.get_ref());
  366. const Mat<eT>& X = tmp.M;
  367. return internal_sum_log_p(X);
  368. }
  369. template<typename eT>
  370. template<typename T1>
  371. inline
  372. eT
  373. gmm_full<eT>::sum_log_p(const Base<eT,T1>& expr, const uword gaus_id) const
  374. {
  375. arma_extra_debug_sigprint();
  376. const quasi_unwrap<T1> tmp(expr.get_ref());
  377. const Mat<eT>& X = tmp.M;
  378. return internal_sum_log_p(X, gaus_id);
  379. }
  380. template<typename eT>
  381. template<typename T1>
  382. inline
  383. eT
  384. gmm_full<eT>::avg_log_p(const Base<eT,T1>& expr) const
  385. {
  386. arma_extra_debug_sigprint();
  387. const quasi_unwrap<T1> tmp(expr.get_ref());
  388. const Mat<eT>& X = tmp.M;
  389. return internal_avg_log_p(X);
  390. }
  391. template<typename eT>
  392. template<typename T1>
  393. inline
  394. eT
  395. gmm_full<eT>::avg_log_p(const Base<eT,T1>& expr, const uword gaus_id) const
  396. {
  397. arma_extra_debug_sigprint();
  398. const quasi_unwrap<T1> tmp(expr.get_ref());
  399. const Mat<eT>& X = tmp.M;
  400. return internal_avg_log_p(X, gaus_id);
  401. }
  402. template<typename eT>
  403. template<typename T1>
  404. inline
  405. uword
  406. gmm_full<eT>::assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == true))>::result* junk) const
  407. {
  408. arma_extra_debug_sigprint();
  409. arma_ignore(junk);
  410. const quasi_unwrap<T1> tmp(expr);
  411. const Mat<eT>& X = tmp.M;
  412. return internal_scalar_assign(X, dist);
  413. }
  414. template<typename eT>
  415. template<typename T1>
  416. inline
  417. urowvec
  418. gmm_full<eT>::assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == false))>::result* junk) const
  419. {
  420. arma_extra_debug_sigprint();
  421. arma_ignore(junk);
  422. urowvec out;
  423. const quasi_unwrap<T1> tmp(expr);
  424. const Mat<eT>& X = tmp.M;
  425. internal_vec_assign(out, X, dist);
  426. return out;
  427. }
  428. template<typename eT>
  429. template<typename T1>
  430. inline
  431. urowvec
  432. gmm_full<eT>::raw_hist(const Base<eT,T1>& expr, const gmm_dist_mode& dist_mode) const
  433. {
  434. arma_extra_debug_sigprint();
  435. const unwrap<T1> tmp(expr.get_ref());
  436. const Mat<eT>& X = tmp.M;
  437. arma_debug_check( (X.n_rows != means.n_rows), "gmm_full::raw_hist(): incompatible dimensions" );
  438. arma_debug_check( ((dist_mode != eucl_dist) && (dist_mode != prob_dist)), "gmm_full::raw_hist(): unsupported distance mode" );
  439. urowvec hist;
  440. internal_raw_hist(hist, X, dist_mode);
  441. return hist;
  442. }
  443. template<typename eT>
  444. template<typename T1>
  445. inline
  446. Row<eT>
  447. gmm_full<eT>::norm_hist(const Base<eT,T1>& expr, const gmm_dist_mode& dist_mode) const
  448. {
  449. arma_extra_debug_sigprint();
  450. const unwrap<T1> tmp(expr.get_ref());
  451. const Mat<eT>& X = tmp.M;
  452. arma_debug_check( (X.n_rows != means.n_rows), "gmm_full::norm_hist(): incompatible dimensions" );
  453. arma_debug_check( ((dist_mode != eucl_dist) && (dist_mode != prob_dist)), "gmm_full::norm_hist(): unsupported distance mode" );
  454. urowvec hist;
  455. internal_raw_hist(hist, X, dist_mode);
  456. const uword hist_n_elem = hist.n_elem;
  457. const uword* hist_mem = hist.memptr();
  458. eT acc = eT(0);
  459. for(uword i=0; i<hist_n_elem; ++i) { acc += eT(hist_mem[i]); }
  460. if(acc == eT(0)) { acc = eT(1); }
  461. Row<eT> out(hist_n_elem);
  462. eT* out_mem = out.memptr();
  463. for(uword i=0; i<hist_n_elem; ++i) { out_mem[i] = eT(hist_mem[i]) / acc; }
  464. return out;
  465. }
  466. template<typename eT>
  467. template<typename T1>
  468. inline
  469. bool
  470. gmm_full<eT>::learn
  471. (
  472. const Base<eT,T1>& data,
  473. const uword N_gaus,
  474. const gmm_dist_mode& dist_mode,
  475. const gmm_seed_mode& seed_mode,
  476. const uword km_iter,
  477. const uword em_iter,
  478. const eT var_floor,
  479. const bool print_mode
  480. )
  481. {
  482. arma_extra_debug_sigprint();
  483. const bool dist_mode_ok = (dist_mode == eucl_dist) || (dist_mode == maha_dist);
  484. const bool seed_mode_ok = \
  485. (seed_mode == keep_existing)
  486. || (seed_mode == static_subset)
  487. || (seed_mode == static_spread)
  488. || (seed_mode == random_subset)
  489. || (seed_mode == random_spread);
  490. arma_debug_check( (dist_mode_ok == false), "gmm_full::learn(): dist_mode must be eucl_dist or maha_dist" );
  491. arma_debug_check( (seed_mode_ok == false), "gmm_full::learn(): unknown seed_mode" );
  492. arma_debug_check( (var_floor < eT(0) ), "gmm_full::learn(): variance floor is negative" );
  493. const unwrap<T1> tmp_X(data.get_ref());
  494. const Mat<eT>& X = tmp_X.M;
  495. if(X.is_empty() ) { arma_debug_warn("gmm_full::learn(): given matrix is empty" ); return false; }
  496. if(X.is_finite() == false) { arma_debug_warn("gmm_full::learn(): given matrix has non-finite values"); return false; }
  497. if(N_gaus == 0) { reset(); return true; }
  498. if(dist_mode == maha_dist)
  499. {
  500. mah_aux = var(X,1,1);
  501. const uword mah_aux_n_elem = mah_aux.n_elem;
  502. eT* mah_aux_mem = mah_aux.memptr();
  503. for(uword i=0; i < mah_aux_n_elem; ++i)
  504. {
  505. const eT val = mah_aux_mem[i];
  506. mah_aux_mem[i] = ((val != eT(0)) && arma_isfinite(val)) ? eT(1) / val : eT(1);
  507. }
  508. }
  509. // copy current model, in case of failure by k-means and/or EM
  510. const gmm_full<eT> orig = (*this);
  511. // initial means
  512. if(seed_mode == keep_existing)
  513. {
  514. if(means.is_empty() ) { arma_debug_warn("gmm_full::learn(): no existing means" ); return false; }
  515. if(X.n_rows != means.n_rows) { arma_debug_warn("gmm_full::learn(): dimensionality mismatch"); return false; }
  516. // TODO: also check for number of vectors?
  517. }
  518. else
  519. {
  520. if(X.n_cols < N_gaus) { arma_debug_warn("gmm_full::learn(): number of vectors is less than number of gaussians"); return false; }
  521. reset(X.n_rows, N_gaus);
  522. if(print_mode) { get_cout_stream() << "gmm_full::learn(): generating initial means\n"; get_cout_stream().flush(); }
  523. if(dist_mode == eucl_dist) { generate_initial_means<1>(X, seed_mode); }
  524. else if(dist_mode == maha_dist) { generate_initial_means<2>(X, seed_mode); }
  525. }
  526. // k-means
  527. if(km_iter > 0)
  528. {
  529. const arma_ostream_state stream_state(get_cout_stream());
  530. bool status = false;
  531. if(dist_mode == eucl_dist) { status = km_iterate<1>(X, km_iter, print_mode); }
  532. else if(dist_mode == maha_dist) { status = km_iterate<2>(X, km_iter, print_mode); }
  533. stream_state.restore(get_cout_stream());
  534. if(status == false) { arma_debug_warn("gmm_full::learn(): k-means algorithm failed; not enough data, or too many gaussians requested"); init(orig); return false; }
  535. }
  536. // initial fcovs
  537. const eT var_floor_actual = (eT(var_floor) > eT(0)) ? eT(var_floor) : std::numeric_limits<eT>::min();
  538. if(seed_mode != keep_existing)
  539. {
  540. if(print_mode) { get_cout_stream() << "gmm_full::learn(): generating initial covariances\n"; get_cout_stream().flush(); }
  541. if(dist_mode == eucl_dist) { generate_initial_params<1>(X, var_floor_actual); }
  542. else if(dist_mode == maha_dist) { generate_initial_params<2>(X, var_floor_actual); }
  543. }
  544. // EM algorithm
  545. if(em_iter > 0)
  546. {
  547. const arma_ostream_state stream_state(get_cout_stream());
  548. const bool status = em_iterate(X, em_iter, var_floor_actual, print_mode);
  549. stream_state.restore(get_cout_stream());
  550. if(status == false) { arma_debug_warn("gmm_full::learn(): EM algorithm failed"); init(orig); return false; }
  551. }
  552. mah_aux.reset();
  553. init_constants();
  554. return true;
  555. }
  556. //
  557. //
  558. //
  559. template<typename eT>
  560. inline
  561. void
  562. gmm_full<eT>::init(const gmm_full<eT>& x)
  563. {
  564. arma_extra_debug_sigprint();
  565. gmm_full<eT>& t = *this;
  566. if(&t != &x)
  567. {
  568. access::rw(t.means) = x.means;
  569. access::rw(t.fcovs) = x.fcovs;
  570. access::rw(t.hefts) = x.hefts;
  571. init_constants();
  572. }
  573. }
  574. template<typename eT>
  575. inline
  576. void
  577. gmm_full<eT>::init(const gmm_diag<eT>& x)
  578. {
  579. arma_extra_debug_sigprint();
  580. access::rw(hefts) = x.hefts;
  581. access::rw(means) = x.means;
  582. const uword N_dims = x.means.n_rows;
  583. const uword N_gaus = x.means.n_cols;
  584. access::rw(fcovs).zeros(N_dims,N_dims,N_gaus);
  585. for(uword g=0; g < N_gaus; ++g)
  586. {
  587. Mat<eT>& fcov = access::rw(fcovs).slice(g);
  588. const eT* dcov_mem = x.dcovs.colptr(g);
  589. for(uword d=0; d < N_dims; ++d)
  590. {
  591. fcov.at(d,d) = dcov_mem[d];
  592. }
  593. }
  594. init_constants();
  595. }
  596. template<typename eT>
  597. inline
  598. void
  599. gmm_full<eT>::init(const uword in_n_dims, const uword in_n_gaus)
  600. {
  601. arma_extra_debug_sigprint();
  602. access::rw(means).zeros(in_n_dims, in_n_gaus);
  603. access::rw(fcovs).zeros(in_n_dims, in_n_dims, in_n_gaus);
  604. for(uword g=0; g < in_n_gaus; ++g)
  605. {
  606. access::rw(fcovs).slice(g).diag().ones();
  607. }
  608. access::rw(hefts).set_size(in_n_gaus);
  609. access::rw(hefts).fill(eT(1) / eT(in_n_gaus));
  610. init_constants();
  611. }
  612. template<typename eT>
  613. inline
  614. void
  615. gmm_full<eT>::init_constants(const bool calc_chol)
  616. {
  617. arma_extra_debug_sigprint();
  618. const uword N_dims = means.n_rows;
  619. const uword N_gaus = means.n_cols;
  620. const eT tmp = (eT(N_dims)/eT(2)) * std::log(eT(2) * Datum<eT>::pi);
  621. //
  622. inv_fcovs.copy_size(fcovs);
  623. log_det_etc.set_size(N_gaus);
  624. Mat<eT> tmp_inv;
  625. for(uword g=0; g < N_gaus; ++g)
  626. {
  627. const Mat<eT>& fcov = fcovs.slice(g);
  628. Mat<eT>& inv_fcov = inv_fcovs.slice(g);
  629. //const bool inv_ok = auxlib::inv(tmp_inv, fcov);
  630. const bool inv_ok = auxlib::inv_sympd(tmp_inv, fcov);
  631. eT log_det_val = eT(0);
  632. eT log_det_sign = eT(0);
  633. log_det(log_det_val, log_det_sign, fcov);
  634. const bool log_det_ok = ( (arma_isfinite(log_det_val)) && (log_det_sign > eT(0)) );
  635. if(inv_ok && log_det_ok)
  636. {
  637. inv_fcov = tmp_inv;
  638. }
  639. else
  640. {
  641. // last resort: treat the covariance matrix as diagonal
  642. inv_fcov.zeros();
  643. log_det_val = eT(0);
  644. for(uword d=0; d < N_dims; ++d)
  645. {
  646. const eT sanitised_val = (std::max)( eT(fcov.at(d,d)), eT(std::numeric_limits<eT>::min()) );
  647. inv_fcov.at(d,d) = eT(1) / sanitised_val;
  648. log_det_val += std::log(sanitised_val);
  649. }
  650. }
  651. log_det_etc[g] = eT(-1) * ( tmp + eT(0.5) * log_det_val );
  652. }
  653. //
  654. eT* hefts_mem = access::rw(hefts).memptr();
  655. for(uword g=0; g < N_gaus; ++g)
  656. {
  657. hefts_mem[g] = (std::max)( hefts_mem[g], std::numeric_limits<eT>::min() );
  658. }
  659. log_hefts = log(hefts);
  660. if(calc_chol)
  661. {
  662. chol_fcovs.copy_size(fcovs);
  663. Mat<eT> tmp_chol;
  664. for(uword g=0; g < N_gaus; ++g)
  665. {
  666. const Mat<eT>& fcov = fcovs.slice(g);
  667. Mat<eT>& chol_fcov = chol_fcovs.slice(g);
  668. const uword chol_layout = 1; // indicates "lower"
  669. const bool chol_ok = op_chol::apply_direct(tmp_chol, fcov, chol_layout);
  670. if(chol_ok)
  671. {
  672. chol_fcov = tmp_chol;
  673. }
  674. else
  675. {
  676. // last resort: treat the covariance matrix as diagonal
  677. chol_fcov.zeros();
  678. for(uword d=0; d < N_dims; ++d)
  679. {
  680. const eT sanitised_val = (std::max)( eT(fcov.at(d,d)), eT(std::numeric_limits<eT>::min()) );
  681. chol_fcov.at(d,d) = std::sqrt(sanitised_val);
  682. }
  683. }
  684. }
  685. }
  686. }
  687. template<typename eT>
  688. inline
  689. umat
  690. gmm_full<eT>::internal_gen_boundaries(const uword N) const
  691. {
  692. arma_extra_debug_sigprint();
  693. #if defined(ARMA_USE_OPENMP)
  694. const uword n_threads_avail = uword(omp_get_max_threads());
  695. const uword n_threads = (n_threads_avail > 0) ? ( (n_threads_avail <= N) ? n_threads_avail : 1 ) : 1;
  696. #else
  697. static const uword n_threads = 1;
  698. #endif
  699. // get_cout_stream() << "gmm_full::internal_gen_boundaries(): n_threads: " << n_threads << '\n';
  700. umat boundaries(2, n_threads);
  701. if(N > 0)
  702. {
  703. const uword chunk_size = N / n_threads;
  704. uword count = 0;
  705. for(uword t=0; t<n_threads; t++)
  706. {
  707. boundaries.at(0,t) = count;
  708. count += chunk_size;
  709. boundaries.at(1,t) = count-1;
  710. }
  711. boundaries.at(1,n_threads-1) = N - 1;
  712. }
  713. else
  714. {
  715. boundaries.zeros();
  716. }
  717. // get_cout_stream() << "gmm_full::internal_gen_boundaries(): boundaries: " << '\n' << boundaries << '\n';
  718. return boundaries;
  719. }
  720. template<typename eT>
  721. inline
  722. eT
  723. gmm_full<eT>::internal_scalar_log_p(const eT* x) const
  724. {
  725. arma_extra_debug_sigprint();
  726. const eT* log_hefts_mem = log_hefts.mem;
  727. const uword N_gaus = means.n_cols;
  728. if(N_gaus > 0)
  729. {
  730. eT log_sum = internal_scalar_log_p(x, 0) + log_hefts_mem[0];
  731. for(uword g=1; g < N_gaus; ++g)
  732. {
  733. const eT log_val = internal_scalar_log_p(x, g) + log_hefts_mem[g];
  734. log_sum = log_add_exp(log_sum, log_val);
  735. }
  736. return log_sum;
  737. }
  738. else
  739. {
  740. return -Datum<eT>::inf;
  741. }
  742. }
  743. template<typename eT>
  744. inline
  745. eT
  746. gmm_full<eT>::internal_scalar_log_p(const eT* x, const uword g) const
  747. {
  748. arma_extra_debug_sigprint();
  749. const uword N_dims = means.n_rows;
  750. const eT* mean_mem = means.colptr(g);
  751. eT outer_acc = eT(0);
  752. const eT* inv_fcov_coldata = inv_fcovs.slice(g).memptr();
  753. for(uword i=0; i < N_dims; ++i)
  754. {
  755. eT inner_acc = eT(0);
  756. for(uword j=0; j < N_dims; ++j)
  757. {
  758. inner_acc += (x[j] - mean_mem[j]) * inv_fcov_coldata[j];
  759. }
  760. inv_fcov_coldata += N_dims;
  761. outer_acc += inner_acc * (x[i] - mean_mem[i]);
  762. }
  763. return eT(-0.5)*outer_acc + log_det_etc.mem[g];
  764. }
  765. template<typename eT>
  766. inline
  767. Row<eT>
  768. gmm_full<eT>::internal_vec_log_p(const Mat<eT>& X) const
  769. {
  770. arma_extra_debug_sigprint();
  771. const uword N_dims = means.n_rows;
  772. const uword N_samples = X.n_cols;
  773. arma_debug_check( (X.n_rows != N_dims), "gmm_full::log_p(): incompatible dimensions" );
  774. Row<eT> out(N_samples);
  775. if(N_samples > 0)
  776. {
  777. #if defined(ARMA_USE_OPENMP)
  778. {
  779. const umat boundaries = internal_gen_boundaries(N_samples);
  780. const uword n_threads = boundaries.n_cols;
  781. #pragma omp parallel for schedule(static)
  782. for(uword t=0; t < n_threads; ++t)
  783. {
  784. const uword start_index = boundaries.at(0,t);
  785. const uword end_index = boundaries.at(1,t);
  786. eT* out_mem = out.memptr();
  787. for(uword i=start_index; i <= end_index; ++i)
  788. {
  789. out_mem[i] = internal_scalar_log_p( X.colptr(i) );
  790. }
  791. }
  792. }
  793. #else
  794. {
  795. eT* out_mem = out.memptr();
  796. for(uword i=0; i < N_samples; ++i)
  797. {
  798. out_mem[i] = internal_scalar_log_p( X.colptr(i) );
  799. }
  800. }
  801. #endif
  802. }
  803. return out;
  804. }
  805. template<typename eT>
  806. inline
  807. Row<eT>
  808. gmm_full<eT>::internal_vec_log_p(const Mat<eT>& X, const uword gaus_id) const
  809. {
  810. arma_extra_debug_sigprint();
  811. const uword N_dims = means.n_rows;
  812. const uword N_samples = X.n_cols;
  813. arma_debug_check( (X.n_rows != N_dims), "gmm_full::log_p(): incompatible dimensions" );
  814. arma_debug_check( (gaus_id >= means.n_cols), "gmm_full::log_p(): specified gaussian is out of range" );
  815. Row<eT> out(N_samples);
  816. if(N_samples > 0)
  817. {
  818. #if defined(ARMA_USE_OPENMP)
  819. {
  820. const umat boundaries = internal_gen_boundaries(N_samples);
  821. const uword n_threads = boundaries.n_cols;
  822. #pragma omp parallel for schedule(static)
  823. for(uword t=0; t < n_threads; ++t)
  824. {
  825. const uword start_index = boundaries.at(0,t);
  826. const uword end_index = boundaries.at(1,t);
  827. eT* out_mem = out.memptr();
  828. for(uword i=start_index; i <= end_index; ++i)
  829. {
  830. out_mem[i] = internal_scalar_log_p( X.colptr(i), gaus_id );
  831. }
  832. }
  833. }
  834. #else
  835. {
  836. eT* out_mem = out.memptr();
  837. for(uword i=0; i < N_samples; ++i)
  838. {
  839. out_mem[i] = internal_scalar_log_p( X.colptr(i), gaus_id );
  840. }
  841. }
  842. #endif
  843. }
  844. return out;
  845. }
  846. template<typename eT>
  847. inline
  848. eT
  849. gmm_full<eT>::internal_sum_log_p(const Mat<eT>& X) const
  850. {
  851. arma_extra_debug_sigprint();
  852. arma_debug_check( (X.n_rows != means.n_rows), "gmm_full::sum_log_p(): incompatible dimensions" );
  853. const uword N = X.n_cols;
  854. if(N == 0) { return (-Datum<eT>::inf); }
  855. #if defined(ARMA_USE_OPENMP)
  856. {
  857. const umat boundaries = internal_gen_boundaries(N);
  858. const uword n_threads = boundaries.n_cols;
  859. Col<eT> t_accs(n_threads, fill::zeros);
  860. #pragma omp parallel for schedule(static)
  861. for(uword t=0; t < n_threads; ++t)
  862. {
  863. const uword start_index = boundaries.at(0,t);
  864. const uword end_index = boundaries.at(1,t);
  865. eT t_acc = eT(0);
  866. for(uword i=start_index; i <= end_index; ++i)
  867. {
  868. t_acc += internal_scalar_log_p( X.colptr(i) );
  869. }
  870. t_accs[t] = t_acc;
  871. }
  872. return eT(accu(t_accs));
  873. }
  874. #else
  875. {
  876. eT acc = eT(0);
  877. for(uword i=0; i<N; ++i)
  878. {
  879. acc += internal_scalar_log_p( X.colptr(i) );
  880. }
  881. return acc;
  882. }
  883. #endif
  884. }
  885. template<typename eT>
  886. inline
  887. eT
  888. gmm_full<eT>::internal_sum_log_p(const Mat<eT>& X, const uword gaus_id) const
  889. {
  890. arma_extra_debug_sigprint();
  891. arma_debug_check( (X.n_rows != means.n_rows), "gmm_full::sum_log_p(): incompatible dimensions" );
  892. arma_debug_check( (gaus_id >= means.n_cols), "gmm_full::sum_log_p(): specified gaussian is out of range" );
  893. const uword N = X.n_cols;
  894. if(N == 0) { return (-Datum<eT>::inf); }
  895. #if defined(ARMA_USE_OPENMP)
  896. {
  897. const umat boundaries = internal_gen_boundaries(N);
  898. const uword n_threads = boundaries.n_cols;
  899. Col<eT> t_accs(n_threads, fill::zeros);
  900. #pragma omp parallel for schedule(static)
  901. for(uword t=0; t < n_threads; ++t)
  902. {
  903. const uword start_index = boundaries.at(0,t);
  904. const uword end_index = boundaries.at(1,t);
  905. eT t_acc = eT(0);
  906. for(uword i=start_index; i <= end_index; ++i)
  907. {
  908. t_acc += internal_scalar_log_p( X.colptr(i), gaus_id );
  909. }
  910. t_accs[t] = t_acc;
  911. }
  912. return eT(accu(t_accs));
  913. }
  914. #else
  915. {
  916. eT acc = eT(0);
  917. for(uword i=0; i<N; ++i)
  918. {
  919. acc += internal_scalar_log_p( X.colptr(i), gaus_id );
  920. }
  921. return acc;
  922. }
  923. #endif
  924. }
  925. template<typename eT>
  926. inline
  927. eT
  928. gmm_full<eT>::internal_avg_log_p(const Mat<eT>& X) const
  929. {
  930. arma_extra_debug_sigprint();
  931. const uword N_dims = means.n_rows;
  932. const uword N_samples = X.n_cols;
  933. arma_debug_check( (X.n_rows != N_dims), "gmm_full::avg_log_p(): incompatible dimensions" );
  934. if(N_samples == 0) { return (-Datum<eT>::inf); }
  935. #if defined(ARMA_USE_OPENMP)
  936. {
  937. const umat boundaries = internal_gen_boundaries(N_samples);
  938. const uword n_threads = boundaries.n_cols;
  939. field< running_mean_scalar<eT> > t_running_means(n_threads);
  940. #pragma omp parallel for schedule(static)
  941. for(uword t=0; t < n_threads; ++t)
  942. {
  943. const uword start_index = boundaries.at(0,t);
  944. const uword end_index = boundaries.at(1,t);
  945. running_mean_scalar<eT>& current_running_mean = t_running_means[t];
  946. for(uword i=start_index; i <= end_index; ++i)
  947. {
  948. current_running_mean( internal_scalar_log_p( X.colptr(i) ) );
  949. }
  950. }
  951. eT avg = eT(0);
  952. for(uword t=0; t < n_threads; ++t)
  953. {
  954. running_mean_scalar<eT>& current_running_mean = t_running_means[t];
  955. const eT w = eT(current_running_mean.count()) / eT(N_samples);
  956. avg += w * current_running_mean.mean();
  957. }
  958. return avg;
  959. }
  960. #else
  961. {
  962. running_mean_scalar<eT> running_mean;
  963. for(uword i=0; i < N_samples; ++i)
  964. {
  965. running_mean( internal_scalar_log_p( X.colptr(i) ) );
  966. }
  967. return running_mean.mean();
  968. }
  969. #endif
  970. }
  971. template<typename eT>
  972. inline
  973. eT
  974. gmm_full<eT>::internal_avg_log_p(const Mat<eT>& X, const uword gaus_id) const
  975. {
  976. arma_extra_debug_sigprint();
  977. const uword N_dims = means.n_rows;
  978. const uword N_samples = X.n_cols;
  979. arma_debug_check( (X.n_rows != N_dims), "gmm_full::avg_log_p(): incompatible dimensions" );
  980. arma_debug_check( (gaus_id >= means.n_cols), "gmm_full::avg_log_p(): specified gaussian is out of range" );
  981. if(N_samples == 0) { return (-Datum<eT>::inf); }
  982. #if defined(ARMA_USE_OPENMP)
  983. {
  984. const umat boundaries = internal_gen_boundaries(N_samples);
  985. const uword n_threads = boundaries.n_cols;
  986. field< running_mean_scalar<eT> > t_running_means(n_threads);
  987. #pragma omp parallel for schedule(static)
  988. for(uword t=0; t < n_threads; ++t)
  989. {
  990. const uword start_index = boundaries.at(0,t);
  991. const uword end_index = boundaries.at(1,t);
  992. running_mean_scalar<eT>& current_running_mean = t_running_means[t];
  993. for(uword i=start_index; i <= end_index; ++i)
  994. {
  995. current_running_mean( internal_scalar_log_p( X.colptr(i), gaus_id) );
  996. }
  997. }
  998. eT avg = eT(0);
  999. for(uword t=0; t < n_threads; ++t)
  1000. {
  1001. running_mean_scalar<eT>& current_running_mean = t_running_means[t];
  1002. const eT w = eT(current_running_mean.count()) / eT(N_samples);
  1003. avg += w * current_running_mean.mean();
  1004. }
  1005. return avg;
  1006. }
  1007. #else
  1008. {
  1009. running_mean_scalar<eT> running_mean;
  1010. for(uword i=0; i<N_samples; ++i)
  1011. {
  1012. running_mean( internal_scalar_log_p( X.colptr(i), gaus_id ) );
  1013. }
  1014. return running_mean.mean();
  1015. }
  1016. #endif
  1017. }
  1018. template<typename eT>
  1019. inline
  1020. uword
  1021. gmm_full<eT>::internal_scalar_assign(const Mat<eT>& X, const gmm_dist_mode& dist_mode) const
  1022. {
  1023. arma_extra_debug_sigprint();
  1024. const uword N_dims = means.n_rows;
  1025. const uword N_gaus = means.n_cols;
  1026. arma_debug_check( (X.n_rows != N_dims), "gmm_full::assign(): incompatible dimensions" );
  1027. arma_debug_check( (N_gaus == 0), "gmm_full::assign(): model has no means" );
  1028. const eT* X_mem = X.colptr(0);
  1029. if(dist_mode == eucl_dist)
  1030. {
  1031. eT best_dist = Datum<eT>::inf;
  1032. uword best_g = 0;
  1033. for(uword g=0; g < N_gaus; ++g)
  1034. {
  1035. const eT tmp_dist = distance<eT,1>::eval(N_dims, X_mem, means.colptr(g), X_mem);
  1036. if(tmp_dist <= best_dist)
  1037. {
  1038. best_dist = tmp_dist;
  1039. best_g = g;
  1040. }
  1041. }
  1042. return best_g;
  1043. }
  1044. else
  1045. if(dist_mode == prob_dist)
  1046. {
  1047. const eT* log_hefts_mem = log_hefts.memptr();
  1048. eT best_p = -Datum<eT>::inf;
  1049. uword best_g = 0;
  1050. for(uword g=0; g < N_gaus; ++g)
  1051. {
  1052. const eT tmp_p = internal_scalar_log_p(X_mem, g) + log_hefts_mem[g];
  1053. if(tmp_p >= best_p)
  1054. {
  1055. best_p = tmp_p;
  1056. best_g = g;
  1057. }
  1058. }
  1059. return best_g;
  1060. }
  1061. else
  1062. {
  1063. arma_debug_check(true, "gmm_full::assign(): unsupported distance mode");
  1064. }
  1065. return uword(0);
  1066. }
  1067. template<typename eT>
  1068. inline
  1069. void
  1070. gmm_full<eT>::internal_vec_assign(urowvec& out, const Mat<eT>& X, const gmm_dist_mode& dist_mode) const
  1071. {
  1072. arma_extra_debug_sigprint();
  1073. const uword N_dims = means.n_rows;
  1074. const uword N_gaus = means.n_cols;
  1075. arma_debug_check( (X.n_rows != N_dims), "gmm_full::assign(): incompatible dimensions" );
  1076. const uword X_n_cols = (N_gaus > 0) ? X.n_cols : 0;
  1077. out.set_size(1,X_n_cols);
  1078. uword* out_mem = out.memptr();
  1079. if(dist_mode == eucl_dist)
  1080. {
  1081. #if defined(ARMA_USE_OPENMP)
  1082. {
  1083. #pragma omp parallel for schedule(static)
  1084. for(uword i=0; i<X_n_cols; ++i)
  1085. {
  1086. const eT* X_colptr = X.colptr(i);
  1087. eT best_dist = Datum<eT>::inf;
  1088. uword best_g = 0;
  1089. for(uword g=0; g<N_gaus; ++g)
  1090. {
  1091. const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.colptr(g), X_colptr);
  1092. if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; }
  1093. }
  1094. out_mem[i] = best_g;
  1095. }
  1096. }
  1097. #else
  1098. {
  1099. for(uword i=0; i<X_n_cols; ++i)
  1100. {
  1101. const eT* X_colptr = X.colptr(i);
  1102. eT best_dist = Datum<eT>::inf;
  1103. uword best_g = 0;
  1104. for(uword g=0; g<N_gaus; ++g)
  1105. {
  1106. const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.colptr(g), X_colptr);
  1107. if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; }
  1108. }
  1109. out_mem[i] = best_g;
  1110. }
  1111. }
  1112. #endif
  1113. }
  1114. else
  1115. if(dist_mode == prob_dist)
  1116. {
  1117. #if defined(ARMA_USE_OPENMP)
  1118. {
  1119. const umat boundaries = internal_gen_boundaries(X_n_cols);
  1120. const uword n_threads = boundaries.n_cols;
  1121. const eT* log_hefts_mem = log_hefts.memptr();
  1122. #pragma omp parallel for schedule(static)
  1123. for(uword t=0; t < n_threads; ++t)
  1124. {
  1125. const uword start_index = boundaries.at(0,t);
  1126. const uword end_index = boundaries.at(1,t);
  1127. for(uword i=start_index; i <= end_index; ++i)
  1128. {
  1129. const eT* X_colptr = X.colptr(i);
  1130. eT best_p = -Datum<eT>::inf;
  1131. uword best_g = 0;
  1132. for(uword g=0; g<N_gaus; ++g)
  1133. {
  1134. const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g];
  1135. if(tmp_p >= best_p) { best_p = tmp_p; best_g = g; }
  1136. }
  1137. out_mem[i] = best_g;
  1138. }
  1139. }
  1140. }
  1141. #else
  1142. {
  1143. const eT* log_hefts_mem = log_hefts.memptr();
  1144. for(uword i=0; i<X_n_cols; ++i)
  1145. {
  1146. const eT* X_colptr = X.colptr(i);
  1147. eT best_p = -Datum<eT>::inf;
  1148. uword best_g = 0;
  1149. for(uword g=0; g<N_gaus; ++g)
  1150. {
  1151. const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g];
  1152. if(tmp_p >= best_p) { best_p = tmp_p; best_g = g; }
  1153. }
  1154. out_mem[i] = best_g;
  1155. }
  1156. }
  1157. #endif
  1158. }
  1159. else
  1160. {
  1161. arma_debug_check(true, "gmm_full::assign(): unsupported distance mode");
  1162. }
  1163. }
  1164. template<typename eT>
  1165. inline
  1166. void
  1167. gmm_full<eT>::internal_raw_hist(urowvec& hist, const Mat<eT>& X, const gmm_dist_mode& dist_mode) const
  1168. {
  1169. arma_extra_debug_sigprint();
  1170. const uword N_dims = means.n_rows;
  1171. const uword N_gaus = means.n_cols;
  1172. const uword X_n_cols = X.n_cols;
  1173. hist.zeros(N_gaus);
  1174. if(N_gaus == 0) { return; }
  1175. #if defined(ARMA_USE_OPENMP)
  1176. {
  1177. const umat boundaries = internal_gen_boundaries(X_n_cols);
  1178. const uword n_threads = boundaries.n_cols;
  1179. field<urowvec> thread_hist(n_threads);
  1180. for(uword t=0; t < n_threads; ++t) { thread_hist(t).zeros(N_gaus); }
  1181. if(dist_mode == eucl_dist)
  1182. {
  1183. #pragma omp parallel for schedule(static)
  1184. for(uword t=0; t < n_threads; ++t)
  1185. {
  1186. uword* thread_hist_mem = thread_hist(t).memptr();
  1187. const uword start_index = boundaries.at(0,t);
  1188. const uword end_index = boundaries.at(1,t);
  1189. for(uword i=start_index; i <= end_index; ++i)
  1190. {
  1191. const eT* X_colptr = X.colptr(i);
  1192. eT best_dist = Datum<eT>::inf;
  1193. uword best_g = 0;
  1194. for(uword g=0; g < N_gaus; ++g)
  1195. {
  1196. const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.colptr(g), X_colptr);
  1197. if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; }
  1198. }
  1199. thread_hist_mem[best_g]++;
  1200. }
  1201. }
  1202. }
  1203. else
  1204. if(dist_mode == prob_dist)
  1205. {
  1206. const eT* log_hefts_mem = log_hefts.memptr();
  1207. #pragma omp parallel for schedule(static)
  1208. for(uword t=0; t < n_threads; ++t)
  1209. {
  1210. uword* thread_hist_mem = thread_hist(t).memptr();
  1211. const uword start_index = boundaries.at(0,t);
  1212. const uword end_index = boundaries.at(1,t);
  1213. for(uword i=start_index; i <= end_index; ++i)
  1214. {
  1215. const eT* X_colptr = X.colptr(i);
  1216. eT best_p = -Datum<eT>::inf;
  1217. uword best_g = 0;
  1218. for(uword g=0; g < N_gaus; ++g)
  1219. {
  1220. const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g];
  1221. if(tmp_p >= best_p) { best_p = tmp_p; best_g = g; }
  1222. }
  1223. thread_hist_mem[best_g]++;
  1224. }
  1225. }
  1226. }
  1227. // reduction
  1228. for(uword t=0; t < n_threads; ++t)
  1229. {
  1230. hist += thread_hist(t);
  1231. }
  1232. }
  1233. #else
  1234. {
  1235. uword* hist_mem = hist.memptr();
  1236. if(dist_mode == eucl_dist)
  1237. {
  1238. for(uword i=0; i<X_n_cols; ++i)
  1239. {
  1240. const eT* X_colptr = X.colptr(i);
  1241. eT best_dist = Datum<eT>::inf;
  1242. uword best_g = 0;
  1243. for(uword g=0; g < N_gaus; ++g)
  1244. {
  1245. const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.colptr(g), X_colptr);
  1246. if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; }
  1247. }
  1248. hist_mem[best_g]++;
  1249. }
  1250. }
  1251. else
  1252. if(dist_mode == prob_dist)
  1253. {
  1254. const eT* log_hefts_mem = log_hefts.memptr();
  1255. for(uword i=0; i<X_n_cols; ++i)
  1256. {
  1257. const eT* X_colptr = X.colptr(i);
  1258. eT best_p = -Datum<eT>::inf;
  1259. uword best_g = 0;
  1260. for(uword g=0; g < N_gaus; ++g)
  1261. {
  1262. const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g];
  1263. if(tmp_p >= best_p) { best_p = tmp_p; best_g = g; }
  1264. }
  1265. hist_mem[best_g]++;
  1266. }
  1267. }
  1268. }
  1269. #endif
  1270. }
  1271. template<typename eT>
  1272. template<uword dist_id>
  1273. inline
  1274. void
  1275. gmm_full<eT>::generate_initial_means(const Mat<eT>& X, const gmm_seed_mode& seed_mode)
  1276. {
  1277. arma_extra_debug_sigprint();
  1278. const uword N_dims = means.n_rows;
  1279. const uword N_gaus = means.n_cols;
  1280. if( (seed_mode == static_subset) || (seed_mode == random_subset) )
  1281. {
  1282. uvec initial_indices;
  1283. if(seed_mode == static_subset) { initial_indices = linspace<uvec>(0, X.n_cols-1, N_gaus); }
  1284. else if(seed_mode == random_subset) { initial_indices = randperm<uvec>(X.n_cols, N_gaus); }
  1285. // initial_indices.print("initial_indices:");
  1286. access::rw(means) = X.cols(initial_indices);
  1287. }
  1288. else
  1289. if( (seed_mode == static_spread) || (seed_mode == random_spread) )
  1290. {
  1291. // going through all of the samples can be extremely time consuming;
  1292. // instead, if there are enough samples, randomly choose samples with probability 0.1
  1293. const bool use_sampling = ((X.n_cols/uword(100)) > N_gaus);
  1294. const uword step = (use_sampling) ? uword(10) : uword(1);
  1295. uword start_index = 0;
  1296. if(seed_mode == static_spread) { start_index = X.n_cols / 2; }
  1297. else if(seed_mode == random_spread) { start_index = as_scalar(randi<uvec>(1, distr_param(0,X.n_cols-1))); }
  1298. access::rw(means).col(0) = X.unsafe_col(start_index);
  1299. const eT* mah_aux_mem = mah_aux.memptr();
  1300. running_stat<double> rs;
  1301. for(uword g=1; g < N_gaus; ++g)
  1302. {
  1303. eT max_dist = eT(0);
  1304. uword best_i = uword(0);
  1305. uword start_i = uword(0);
  1306. if(use_sampling)
  1307. {
  1308. uword start_i_proposed = uword(0);
  1309. if(seed_mode == static_spread) { start_i_proposed = g % uword(10); }
  1310. if(seed_mode == random_spread) { start_i_proposed = as_scalar(randi<uvec>(1, distr_param(0,9))); }
  1311. if(start_i_proposed < X.n_cols) { start_i = start_i_proposed; }
  1312. }
  1313. for(uword i=start_i; i < X.n_cols; i += step)
  1314. {
  1315. rs.reset();
  1316. const eT* X_colptr = X.colptr(i);
  1317. bool ignore_i = false;
  1318. // find the average distance between sample i and the means so far
  1319. for(uword h = 0; h < g; ++h)
  1320. {
  1321. const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, means.colptr(h), mah_aux_mem);
  1322. // ignore sample already selected as a mean
  1323. if(dist == eT(0)) { ignore_i = true; break; }
  1324. else { rs(dist); }
  1325. }
  1326. if( (rs.mean() >= max_dist) && (ignore_i == false))
  1327. {
  1328. max_dist = eT(rs.mean()); best_i = i;
  1329. }
  1330. }
  1331. // set the mean to the sample that is the furthest away from the means so far
  1332. access::rw(means).col(g) = X.unsafe_col(best_i);
  1333. }
  1334. }
  1335. // get_cout_stream() << "generate_initial_means():" << '\n';
  1336. // means.print();
  1337. }
  1338. template<typename eT>
  1339. template<uword dist_id>
  1340. inline
  1341. void
  1342. gmm_full<eT>::generate_initial_params(const Mat<eT>& X, const eT var_floor)
  1343. {
  1344. arma_extra_debug_sigprint();
  1345. const uword N_dims = means.n_rows;
  1346. const uword N_gaus = means.n_cols;
  1347. const eT* mah_aux_mem = mah_aux.memptr();
  1348. const uword X_n_cols = X.n_cols;
  1349. if(X_n_cols == 0) { return; }
  1350. // as the covariances are calculated via accumulators,
  1351. // the means also need to be calculated via accumulators to ensure numerical consistency
  1352. Mat<eT> acc_means(N_dims, N_gaus, fill::zeros);
  1353. Mat<eT> acc_dcovs(N_dims, N_gaus, fill::zeros);
  1354. Row<uword> acc_hefts(N_gaus, fill::zeros);
  1355. uword* acc_hefts_mem = acc_hefts.memptr();
  1356. #if defined(ARMA_USE_OPENMP)
  1357. {
  1358. const umat boundaries = internal_gen_boundaries(X_n_cols);
  1359. const uword n_threads = boundaries.n_cols;
  1360. field< Mat<eT> > t_acc_means(n_threads);
  1361. field< Mat<eT> > t_acc_dcovs(n_threads);
  1362. field< Row<uword> > t_acc_hefts(n_threads);
  1363. for(uword t=0; t < n_threads; ++t)
  1364. {
  1365. t_acc_means(t).zeros(N_dims, N_gaus);
  1366. t_acc_dcovs(t).zeros(N_dims, N_gaus);
  1367. t_acc_hefts(t).zeros(N_gaus);
  1368. }
  1369. #pragma omp parallel for schedule(static)
  1370. for(uword t=0; t < n_threads; ++t)
  1371. {
  1372. uword* t_acc_hefts_mem = t_acc_hefts(t).memptr();
  1373. const uword start_index = boundaries.at(0,t);
  1374. const uword end_index = boundaries.at(1,t);
  1375. for(uword i=start_index; i <= end_index; ++i)
  1376. {
  1377. const eT* X_colptr = X.colptr(i);
  1378. eT min_dist = Datum<eT>::inf;
  1379. uword best_g = 0;
  1380. for(uword g=0; g<N_gaus; ++g)
  1381. {
  1382. const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, means.colptr(g), mah_aux_mem);
  1383. if(dist < min_dist) { min_dist = dist; best_g = g; }
  1384. }
  1385. eT* t_acc_mean = t_acc_means(t).colptr(best_g);
  1386. eT* t_acc_dcov = t_acc_dcovs(t).colptr(best_g);
  1387. for(uword d=0; d<N_dims; ++d)
  1388. {
  1389. const eT x_d = X_colptr[d];
  1390. t_acc_mean[d] += x_d;
  1391. t_acc_dcov[d] += x_d*x_d;
  1392. }
  1393. t_acc_hefts_mem[best_g]++;
  1394. }
  1395. }
  1396. // reduction
  1397. acc_means = t_acc_means(0);
  1398. acc_dcovs = t_acc_dcovs(0);
  1399. acc_hefts = t_acc_hefts(0);
  1400. for(uword t=1; t < n_threads; ++t)
  1401. {
  1402. acc_means += t_acc_means(t);
  1403. acc_dcovs += t_acc_dcovs(t);
  1404. acc_hefts += t_acc_hefts(t);
  1405. }
  1406. }
  1407. #else
  1408. {
  1409. for(uword i=0; i<X_n_cols; ++i)
  1410. {
  1411. const eT* X_colptr = X.colptr(i);
  1412. eT min_dist = Datum<eT>::inf;
  1413. uword best_g = 0;
  1414. for(uword g=0; g<N_gaus; ++g)
  1415. {
  1416. const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, means.colptr(g), mah_aux_mem);
  1417. if(dist < min_dist) { min_dist = dist; best_g = g; }
  1418. }
  1419. eT* acc_mean = acc_means.colptr(best_g);
  1420. eT* acc_dcov = acc_dcovs.colptr(best_g);
  1421. for(uword d=0; d<N_dims; ++d)
  1422. {
  1423. const eT x_d = X_colptr[d];
  1424. acc_mean[d] += x_d;
  1425. acc_dcov[d] += x_d*x_d;
  1426. }
  1427. acc_hefts_mem[best_g]++;
  1428. }
  1429. }
  1430. #endif
  1431. eT* hefts_mem = access::rw(hefts).memptr();
  1432. for(uword g=0; g<N_gaus; ++g)
  1433. {
  1434. const eT* acc_mean = acc_means.colptr(g);
  1435. const eT* acc_dcov = acc_dcovs.colptr(g);
  1436. const uword acc_heft = acc_hefts_mem[g];
  1437. eT* mean = access::rw(means).colptr(g);
  1438. Mat<eT>& fcov = access::rw(fcovs).slice(g);
  1439. fcov.zeros();
  1440. for(uword d=0; d<N_dims; ++d)
  1441. {
  1442. const eT tmp = acc_mean[d] / eT(acc_heft);
  1443. mean[d] = (acc_heft >= 1) ? tmp : eT(0);
  1444. fcov.at(d,d) = (acc_heft >= 2) ? eT((acc_dcov[d] / eT(acc_heft)) - (tmp*tmp)) : eT(var_floor);
  1445. }
  1446. hefts_mem[g] = eT(acc_heft) / eT(X_n_cols);
  1447. }
  1448. em_fix_params(var_floor);
  1449. }
  1450. //! multi-threaded implementation of k-means, inspired by MapReduce
  1451. template<typename eT>
  1452. template<uword dist_id>
  1453. inline
  1454. bool
  1455. gmm_full<eT>::km_iterate(const Mat<eT>& X, const uword max_iter, const bool verbose)
  1456. {
  1457. arma_extra_debug_sigprint();
  1458. if(verbose)
  1459. {
  1460. get_cout_stream().unsetf(ios::showbase);
  1461. get_cout_stream().unsetf(ios::uppercase);
  1462. get_cout_stream().unsetf(ios::showpos);
  1463. get_cout_stream().unsetf(ios::scientific);
  1464. get_cout_stream().setf(ios::right);
  1465. get_cout_stream().setf(ios::fixed);
  1466. }
  1467. const uword X_n_cols = X.n_cols;
  1468. if(X_n_cols == 0) { return true; }
  1469. const uword N_dims = means.n_rows;
  1470. const uword N_gaus = means.n_cols;
  1471. const eT* mah_aux_mem = mah_aux.memptr();
  1472. Mat<eT> acc_means(N_dims, N_gaus, fill::zeros);
  1473. Row<uword> acc_hefts(N_gaus, fill::zeros);
  1474. Row<uword> last_indx(N_gaus, fill::zeros);
  1475. Mat<eT> new_means = means;
  1476. Mat<eT> old_means = means;
  1477. running_mean_scalar<eT> rs_delta;
  1478. #if defined(ARMA_USE_OPENMP)
  1479. const umat boundaries = internal_gen_boundaries(X_n_cols);
  1480. const uword n_threads = boundaries.n_cols;
  1481. field< Mat<eT> > t_acc_means(n_threads);
  1482. field< Row<uword> > t_acc_hefts(n_threads);
  1483. field< Row<uword> > t_last_indx(n_threads);
  1484. #else
  1485. const uword n_threads = 1;
  1486. #endif
  1487. if(verbose) { get_cout_stream() << "gmm_full::learn(): k-means: n_threads: " << n_threads << '\n'; get_cout_stream().flush(); }
  1488. for(uword iter=1; iter <= max_iter; ++iter)
  1489. {
  1490. #if defined(ARMA_USE_OPENMP)
  1491. {
  1492. for(uword t=0; t < n_threads; ++t)
  1493. {
  1494. t_acc_means(t).zeros(N_dims, N_gaus);
  1495. t_acc_hefts(t).zeros(N_gaus);
  1496. t_last_indx(t).zeros(N_gaus);
  1497. }
  1498. #pragma omp parallel for schedule(static)
  1499. for(uword t=0; t < n_threads; ++t)
  1500. {
  1501. Mat<eT>& t_acc_means_t = t_acc_means(t);
  1502. uword* t_acc_hefts_mem = t_acc_hefts(t).memptr();
  1503. uword* t_last_indx_mem = t_last_indx(t).memptr();
  1504. const uword start_index = boundaries.at(0,t);
  1505. const uword end_index = boundaries.at(1,t);
  1506. for(uword i=start_index; i <= end_index; ++i)
  1507. {
  1508. const eT* X_colptr = X.colptr(i);
  1509. eT min_dist = Datum<eT>::inf;
  1510. uword best_g = 0;
  1511. for(uword g=0; g<N_gaus; ++g)
  1512. {
  1513. const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, old_means.colptr(g), mah_aux_mem);
  1514. if(dist < min_dist) { min_dist = dist; best_g = g; }
  1515. }
  1516. eT* t_acc_mean = t_acc_means_t.colptr(best_g);
  1517. for(uword d=0; d<N_dims; ++d) { t_acc_mean[d] += X_colptr[d]; }
  1518. t_acc_hefts_mem[best_g]++;
  1519. t_last_indx_mem[best_g] = i;
  1520. }
  1521. }
  1522. // reduction
  1523. acc_means = t_acc_means(0);
  1524. acc_hefts = t_acc_hefts(0);
  1525. for(uword t=1; t < n_threads; ++t)
  1526. {
  1527. acc_means += t_acc_means(t);
  1528. acc_hefts += t_acc_hefts(t);
  1529. }
  1530. for(uword g=0; g < N_gaus; ++g)
  1531. for(uword t=0; t < n_threads; ++t)
  1532. {
  1533. if( t_acc_hefts(t)(g) >= 1 ) { last_indx(g) = t_last_indx(t)(g); }
  1534. }
  1535. }
  1536. #else
  1537. {
  1538. uword* acc_hefts_mem = acc_hefts.memptr();
  1539. uword* last_indx_mem = last_indx.memptr();
  1540. for(uword i=0; i < X_n_cols; ++i)
  1541. {
  1542. const eT* X_colptr = X.colptr(i);
  1543. eT min_dist = Datum<eT>::inf;
  1544. uword best_g = 0;
  1545. for(uword g=0; g<N_gaus; ++g)
  1546. {
  1547. const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, old_means.colptr(g), mah_aux_mem);
  1548. if(dist < min_dist) { min_dist = dist; best_g = g; }
  1549. }
  1550. eT* acc_mean = acc_means.colptr(best_g);
  1551. for(uword d=0; d<N_dims; ++d) { acc_mean[d] += X_colptr[d]; }
  1552. acc_hefts_mem[best_g]++;
  1553. last_indx_mem[best_g] = i;
  1554. }
  1555. }
  1556. #endif
  1557. // generate new means
  1558. uword* acc_hefts_mem = acc_hefts.memptr();
  1559. for(uword g=0; g < N_gaus; ++g)
  1560. {
  1561. const eT* acc_mean = acc_means.colptr(g);
  1562. const uword acc_heft = acc_hefts_mem[g];
  1563. eT* new_mean = access::rw(new_means).colptr(g);
  1564. for(uword d=0; d<N_dims; ++d)
  1565. {
  1566. new_mean[d] = (acc_heft >= 1) ? (acc_mean[d] / eT(acc_heft)) : eT(0);
  1567. }
  1568. }
  1569. // heuristics to resurrect dead means
  1570. const uvec dead_gs = find(acc_hefts == uword(0));
  1571. if(dead_gs.n_elem > 0)
  1572. {
  1573. if(verbose) { get_cout_stream() << "gmm_full::learn(): k-means: recovering from dead means\n"; get_cout_stream().flush(); }
  1574. uword* last_indx_mem = last_indx.memptr();
  1575. const uvec live_gs = sort( find(acc_hefts >= uword(2)), "descend" );
  1576. if(live_gs.n_elem == 0) { return false; }
  1577. uword live_gs_count = 0;
  1578. for(uword dead_gs_count = 0; dead_gs_count < dead_gs.n_elem; ++dead_gs_count)
  1579. {
  1580. const uword dead_g_id = dead_gs(dead_gs_count);
  1581. uword proposed_i = 0;
  1582. if(live_gs_count < live_gs.n_elem)
  1583. {
  1584. const uword live_g_id = live_gs(live_gs_count); ++live_gs_count;
  1585. if(live_g_id == dead_g_id) { return false; }
  1586. // recover by using a sample from a known good mean
  1587. proposed_i = last_indx_mem[live_g_id];
  1588. }
  1589. else
  1590. {
  1591. // recover by using a randomly seleced sample (last resort)
  1592. proposed_i = as_scalar(randi<uvec>(1, distr_param(0,X_n_cols-1)));
  1593. }
  1594. if(proposed_i >= X_n_cols) { return false; }
  1595. new_means.col(dead_g_id) = X.col(proposed_i);
  1596. }
  1597. }
  1598. rs_delta.reset();
  1599. for(uword g=0; g < N_gaus; ++g)
  1600. {
  1601. rs_delta( distance<eT,dist_id>::eval(N_dims, old_means.colptr(g), new_means.colptr(g), mah_aux_mem) );
  1602. }
  1603. if(verbose)
  1604. {
  1605. get_cout_stream() << "gmm_full::learn(): k-means: iteration: ";
  1606. get_cout_stream().unsetf(ios::scientific);
  1607. get_cout_stream().setf(ios::fixed);
  1608. get_cout_stream().width(std::streamsize(4));
  1609. get_cout_stream() << iter;
  1610. get_cout_stream() << " delta: ";
  1611. get_cout_stream().unsetf(ios::fixed);
  1612. //get_cout_stream().setf(ios::scientific);
  1613. get_cout_stream() << rs_delta.mean() << '\n';
  1614. get_cout_stream().flush();
  1615. }
  1616. arma::swap(old_means, new_means);
  1617. if(rs_delta.mean() <= Datum<eT>::eps) { break; }
  1618. }
  1619. access::rw(means) = old_means;
  1620. if(means.is_finite() == false) { return false; }
  1621. return true;
  1622. }
  1623. //! multi-threaded implementation of Expectation-Maximisation, inspired by MapReduce
  1624. template<typename eT>
  1625. inline
  1626. bool
  1627. gmm_full<eT>::em_iterate(const Mat<eT>& X, const uword max_iter, const eT var_floor, const bool verbose)
  1628. {
  1629. arma_extra_debug_sigprint();
  1630. const uword N_dims = means.n_rows;
  1631. const uword N_gaus = means.n_cols;
  1632. if(verbose)
  1633. {
  1634. get_cout_stream().unsetf(ios::showbase);
  1635. get_cout_stream().unsetf(ios::uppercase);
  1636. get_cout_stream().unsetf(ios::showpos);
  1637. get_cout_stream().unsetf(ios::scientific);
  1638. get_cout_stream().setf(ios::right);
  1639. get_cout_stream().setf(ios::fixed);
  1640. }
  1641. const umat boundaries = internal_gen_boundaries(X.n_cols);
  1642. const uword n_threads = boundaries.n_cols;
  1643. field< Mat<eT> > t_acc_means(n_threads);
  1644. field< Cube<eT> > t_acc_fcovs(n_threads);
  1645. field< Col<eT> > t_acc_norm_lhoods(n_threads);
  1646. field< Col<eT> > t_gaus_log_lhoods(n_threads);
  1647. Col<eT> t_progress_log_lhood(n_threads);
  1648. for(uword t=0; t<n_threads; t++)
  1649. {
  1650. t_acc_means[t].set_size(N_dims, N_gaus);
  1651. t_acc_fcovs[t].set_size(N_dims, N_dims, N_gaus);
  1652. t_acc_norm_lhoods[t].set_size(N_gaus);
  1653. t_gaus_log_lhoods[t].set_size(N_gaus);
  1654. }
  1655. if(verbose)
  1656. {
  1657. get_cout_stream() << "gmm_full::learn(): EM: n_threads: " << n_threads << '\n';
  1658. }
  1659. eT old_avg_log_p = -Datum<eT>::inf;
  1660. const bool calc_chol = false;
  1661. for(uword iter=1; iter <= max_iter; ++iter)
  1662. {
  1663. init_constants(calc_chol);
  1664. em_update_params(X, boundaries, t_acc_means, t_acc_fcovs, t_acc_norm_lhoods, t_gaus_log_lhoods, t_progress_log_lhood, var_floor);
  1665. em_fix_params(var_floor);
  1666. const eT new_avg_log_p = accu(t_progress_log_lhood) / eT(t_progress_log_lhood.n_elem);
  1667. if(verbose)
  1668. {
  1669. get_cout_stream() << "gmm_full::learn(): EM: iteration: ";
  1670. get_cout_stream().unsetf(ios::scientific);
  1671. get_cout_stream().setf(ios::fixed);
  1672. get_cout_stream().width(std::streamsize(4));
  1673. get_cout_stream() << iter;
  1674. get_cout_stream() << " avg_log_p: ";
  1675. get_cout_stream().unsetf(ios::fixed);
  1676. //get_cout_stream().setf(ios::scientific);
  1677. get_cout_stream() << new_avg_log_p << '\n';
  1678. get_cout_stream().flush();
  1679. }
  1680. if(arma_isfinite(new_avg_log_p) == false) { return false; }
  1681. if(std::abs(old_avg_log_p - new_avg_log_p) <= Datum<eT>::eps) { break; }
  1682. old_avg_log_p = new_avg_log_p;
  1683. }
  1684. for(uword g=0; g < N_gaus; ++g)
  1685. {
  1686. const Mat<eT>& fcov = fcovs.slice(g);
  1687. if(any(vectorise(fcov.diag()) <= eT(0))) { return false; }
  1688. }
  1689. if(means.is_finite() == false) { return false; }
  1690. if(fcovs.is_finite() == false) { return false; }
  1691. if(hefts.is_finite() == false) { return false; }
  1692. return true;
  1693. }
  1694. template<typename eT>
  1695. inline
  1696. void
  1697. gmm_full<eT>::em_update_params
  1698. (
  1699. const Mat<eT>& X,
  1700. const umat& boundaries,
  1701. field< Mat<eT> >& t_acc_means,
  1702. field< Cube<eT> >& t_acc_fcovs,
  1703. field< Col<eT> >& t_acc_norm_lhoods,
  1704. field< Col<eT> >& t_gaus_log_lhoods,
  1705. Col<eT>& t_progress_log_lhood,
  1706. const eT var_floor
  1707. )
  1708. {
  1709. arma_extra_debug_sigprint();
  1710. const uword n_threads = boundaries.n_cols;
  1711. // em_generate_acc() is the "map" operation, which produces partial accumulators for means, diagonal covariances and hefts
  1712. #if defined(ARMA_USE_OPENMP)
  1713. {
  1714. #pragma omp parallel for schedule(static)
  1715. for(uword t=0; t<n_threads; t++)
  1716. {
  1717. Mat<eT>& acc_means = t_acc_means[t];
  1718. Cube<eT>& acc_fcovs = t_acc_fcovs[t];
  1719. Col<eT>& acc_norm_lhoods = t_acc_norm_lhoods[t];
  1720. Col<eT>& gaus_log_lhoods = t_gaus_log_lhoods[t];
  1721. eT& progress_log_lhood = t_progress_log_lhood[t];
  1722. em_generate_acc(X, boundaries.at(0,t), boundaries.at(1,t), acc_means, acc_fcovs, acc_norm_lhoods, gaus_log_lhoods, progress_log_lhood);
  1723. }
  1724. }
  1725. #else
  1726. {
  1727. em_generate_acc(X, boundaries.at(0,0), boundaries.at(1,0), t_acc_means[0], t_acc_fcovs[0], t_acc_norm_lhoods[0], t_gaus_log_lhoods[0], t_progress_log_lhood[0]);
  1728. }
  1729. #endif
  1730. const uword N_dims = means.n_rows;
  1731. const uword N_gaus = means.n_cols;
  1732. Mat<eT>& final_acc_means = t_acc_means[0];
  1733. Cube<eT>& final_acc_fcovs = t_acc_fcovs[0];
  1734. Col<eT>& final_acc_norm_lhoods = t_acc_norm_lhoods[0];
  1735. // the "reduce" operation, which combines the partial accumulators produced by the separate threads
  1736. for(uword t=1; t<n_threads; t++)
  1737. {
  1738. final_acc_means += t_acc_means[t];
  1739. final_acc_fcovs += t_acc_fcovs[t];
  1740. final_acc_norm_lhoods += t_acc_norm_lhoods[t];
  1741. }
  1742. eT* hefts_mem = access::rw(hefts).memptr();
  1743. Mat<eT> mean_outer(N_dims, N_dims);
  1744. //// update each component without sanity checking
  1745. //for(uword g=0; g < N_gaus; ++g)
  1746. // {
  1747. // const eT acc_norm_lhood = (std::max)( final_acc_norm_lhoods[g], std::numeric_limits<eT>::min() );
  1748. //
  1749. // hefts_mem[g] = acc_norm_lhood / eT(X.n_cols);
  1750. //
  1751. // eT* mean_mem = access::rw(means).colptr(g);
  1752. // eT* acc_mean_mem = final_acc_means.colptr(g);
  1753. //
  1754. // for(uword d=0; d < N_dims; ++d)
  1755. // {
  1756. // mean_mem[d] = acc_mean_mem[d] / acc_norm_lhood;
  1757. // }
  1758. //
  1759. // const Col<eT> mean(mean_mem, N_dims, false, true);
  1760. //
  1761. // mean_outer = mean * mean.t();
  1762. //
  1763. // Mat<eT>& fcov = access::rw(fcovs).slice(g);
  1764. // Mat<eT>& acc_fcov = final_acc_fcovs.slice(g);
  1765. //
  1766. // fcov = acc_fcov / acc_norm_lhood - mean_outer;
  1767. // }
  1768. // conditionally update each component; if only a subset of the hefts was updated, em_fix_params() will sanitise them
  1769. for(uword g=0; g < N_gaus; ++g)
  1770. {
  1771. const eT acc_norm_lhood = (std::max)( final_acc_norm_lhoods[g], std::numeric_limits<eT>::min() );
  1772. if(arma_isfinite(acc_norm_lhood) == false) { continue; }
  1773. eT* acc_mean_mem = final_acc_means.colptr(g);
  1774. for(uword d=0; d < N_dims; ++d)
  1775. {
  1776. acc_mean_mem[d] /= acc_norm_lhood;
  1777. }
  1778. const Col<eT> new_mean(acc_mean_mem, N_dims, false, true);
  1779. mean_outer = new_mean * new_mean.t();
  1780. Mat<eT>& acc_fcov = final_acc_fcovs.slice(g);
  1781. acc_fcov /= acc_norm_lhood;
  1782. acc_fcov -= mean_outer;
  1783. for(uword d=0; d < N_dims; ++d)
  1784. {
  1785. eT& val = acc_fcov.at(d,d);
  1786. if(val < var_floor) { val = var_floor; }
  1787. }
  1788. if(acc_fcov.is_finite() == false) { continue; }
  1789. eT log_det_val = eT(0);
  1790. eT log_det_sign = eT(0);
  1791. log_det(log_det_val, log_det_sign, acc_fcov);
  1792. const bool log_det_ok = ( (arma_isfinite(log_det_val)) && (log_det_sign > eT(0)) );
  1793. const bool inv_ok = (log_det_ok) ? bool(auxlib::inv_sympd(mean_outer, acc_fcov)) : bool(false); // mean_outer is used as a junk matrix
  1794. if(log_det_ok && inv_ok)
  1795. {
  1796. hefts_mem[g] = acc_norm_lhood / eT(X.n_cols);
  1797. eT* mean_mem = access::rw(means).colptr(g);
  1798. for(uword d=0; d < N_dims; ++d)
  1799. {
  1800. mean_mem[d] = acc_mean_mem[d];
  1801. }
  1802. Mat<eT>& fcov = access::rw(fcovs).slice(g);
  1803. fcov = acc_fcov;
  1804. }
  1805. }
  1806. }
  1807. template<typename eT>
  1808. inline
  1809. void
  1810. gmm_full<eT>::em_generate_acc
  1811. (
  1812. const Mat<eT>& X,
  1813. const uword start_index,
  1814. const uword end_index,
  1815. Mat<eT>& acc_means,
  1816. Cube<eT>& acc_fcovs,
  1817. Col<eT>& acc_norm_lhoods,
  1818. Col<eT>& gaus_log_lhoods,
  1819. eT& progress_log_lhood
  1820. )
  1821. const
  1822. {
  1823. arma_extra_debug_sigprint();
  1824. progress_log_lhood = eT(0);
  1825. acc_means.zeros();
  1826. acc_fcovs.zeros();
  1827. acc_norm_lhoods.zeros();
  1828. gaus_log_lhoods.zeros();
  1829. const uword N_dims = means.n_rows;
  1830. const uword N_gaus = means.n_cols;
  1831. const eT* log_hefts_mem = log_hefts.memptr();
  1832. eT* gaus_log_lhoods_mem = gaus_log_lhoods.memptr();
  1833. for(uword i=start_index; i <= end_index; i++)
  1834. {
  1835. const eT* x = X.colptr(i);
  1836. for(uword g=0; g < N_gaus; ++g)
  1837. {
  1838. gaus_log_lhoods_mem[g] = internal_scalar_log_p(x, g) + log_hefts_mem[g];
  1839. }
  1840. eT log_lhood_sum = gaus_log_lhoods_mem[0];
  1841. for(uword g=1; g < N_gaus; ++g)
  1842. {
  1843. log_lhood_sum = log_add_exp(log_lhood_sum, gaus_log_lhoods_mem[g]);
  1844. }
  1845. progress_log_lhood += log_lhood_sum;
  1846. for(uword g=0; g < N_gaus; ++g)
  1847. {
  1848. const eT norm_lhood = std::exp(gaus_log_lhoods_mem[g] - log_lhood_sum);
  1849. acc_norm_lhoods[g] += norm_lhood;
  1850. eT* acc_mean_mem = acc_means.colptr(g);
  1851. for(uword d=0; d < N_dims; ++d)
  1852. {
  1853. acc_mean_mem[d] += x[d] * norm_lhood;
  1854. }
  1855. Mat<eT>& acc_fcov = access::rw(acc_fcovs).slice(g);
  1856. // specialised version of acc_fcov += norm_lhood * (xx * xx.t());
  1857. for(uword d=0; d < N_dims; ++d)
  1858. {
  1859. const uword dp1 = d+1;
  1860. const eT xd = x[d];
  1861. eT* acc_fcov_col_d = acc_fcov.colptr(d) + d;
  1862. eT* acc_fcov_row_d = &(acc_fcov.at(d,dp1));
  1863. (*acc_fcov_col_d) += norm_lhood * (xd * xd); acc_fcov_col_d++;
  1864. for(uword e=dp1; e < N_dims; ++e)
  1865. {
  1866. const eT val = norm_lhood * (xd * x[e]);
  1867. (*acc_fcov_col_d) += val; acc_fcov_col_d++;
  1868. (*acc_fcov_row_d) += val; acc_fcov_row_d += N_dims;
  1869. }
  1870. }
  1871. }
  1872. }
  1873. progress_log_lhood /= eT((end_index - start_index) + 1);
  1874. }
  1875. template<typename eT>
  1876. inline
  1877. void
  1878. gmm_full<eT>::em_fix_params(const eT var_floor)
  1879. {
  1880. arma_extra_debug_sigprint();
  1881. const uword N_dims = means.n_rows;
  1882. const uword N_gaus = means.n_cols;
  1883. const eT var_ceiling = std::numeric_limits<eT>::max();
  1884. for(uword g=0; g < N_gaus; ++g)
  1885. {
  1886. Mat<eT>& fcov = access::rw(fcovs).slice(g);
  1887. for(uword d=0; d < N_dims; ++d)
  1888. {
  1889. eT& var_val = fcov.at(d,d);
  1890. if(var_val < var_floor ) { var_val = var_floor; }
  1891. else if(var_val > var_ceiling) { var_val = var_ceiling; }
  1892. else if(arma_isnan(var_val) ) { var_val = eT(1); }
  1893. }
  1894. }
  1895. eT* hefts_mem = access::rw(hefts).memptr();
  1896. for(uword g1=0; g1 < N_gaus; ++g1)
  1897. {
  1898. if(hefts_mem[g1] > eT(0))
  1899. {
  1900. const eT* means_colptr_g1 = means.colptr(g1);
  1901. for(uword g2=(g1+1); g2 < N_gaus; ++g2)
  1902. {
  1903. if( (hefts_mem[g2] > eT(0)) && (std::abs(hefts_mem[g1] - hefts_mem[g2]) <= std::numeric_limits<eT>::epsilon()) )
  1904. {
  1905. const eT dist = distance<eT,1>::eval(N_dims, means_colptr_g1, means.colptr(g2), means_colptr_g1);
  1906. if(dist == eT(0)) { hefts_mem[g2] = eT(0); }
  1907. }
  1908. }
  1909. }
  1910. }
  1911. const eT heft_floor = std::numeric_limits<eT>::min();
  1912. const eT heft_initial = eT(1) / eT(N_gaus);
  1913. for(uword i=0; i < N_gaus; ++i)
  1914. {
  1915. eT& heft_val = hefts_mem[i];
  1916. if(heft_val < heft_floor) { heft_val = heft_floor; }
  1917. else if(heft_val > eT(1) ) { heft_val = eT(1); }
  1918. else if(arma_isnan(heft_val) ) { heft_val = heft_initial; }
  1919. }
  1920. const eT heft_sum = accu(hefts);
  1921. if((heft_sum < (eT(1) - Datum<eT>::eps)) || (heft_sum > (eT(1) + Datum<eT>::eps))) { access::rw(hefts) /= heft_sum; }
  1922. }
  1923. } // namespace gmm_priv
  1924. //! @}