gmm_diag_meat.hpp 64 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760176117621763176417651766176717681769177017711772177317741775177617771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. //! \addtogroup gmm_diag
  16. //! @{
  17. namespace gmm_priv
  18. {
  19. template<typename eT>
  20. inline
  21. gmm_diag<eT>::~gmm_diag()
  22. {
  23. arma_extra_debug_sigprint_this(this);
  24. arma_type_check(( (is_same_type<eT,float>::value == false) && (is_same_type<eT,double>::value == false) ));
  25. }
  26. template<typename eT>
  27. inline
  28. gmm_diag<eT>::gmm_diag()
  29. {
  30. arma_extra_debug_sigprint_this(this);
  31. }
  32. template<typename eT>
  33. inline
  34. gmm_diag<eT>::gmm_diag(const gmm_diag<eT>& x)
  35. {
  36. arma_extra_debug_sigprint_this(this);
  37. init(x);
  38. }
  39. template<typename eT>
  40. inline
  41. gmm_diag<eT>&
  42. gmm_diag<eT>::operator=(const gmm_diag<eT>& x)
  43. {
  44. arma_extra_debug_sigprint();
  45. init(x);
  46. return *this;
  47. }
  48. template<typename eT>
  49. inline
  50. gmm_diag<eT>::gmm_diag(const gmm_full<eT>& x)
  51. {
  52. arma_extra_debug_sigprint_this(this);
  53. init(x);
  54. }
  55. template<typename eT>
  56. inline
  57. gmm_diag<eT>&
  58. gmm_diag<eT>::operator=(const gmm_full<eT>& x)
  59. {
  60. arma_extra_debug_sigprint();
  61. init(x);
  62. return *this;
  63. }
  64. template<typename eT>
  65. inline
  66. gmm_diag<eT>::gmm_diag(const uword in_n_dims, const uword in_n_gaus)
  67. {
  68. arma_extra_debug_sigprint_this(this);
  69. init(in_n_dims, in_n_gaus);
  70. }
  71. template<typename eT>
  72. inline
  73. void
  74. gmm_diag<eT>::reset()
  75. {
  76. arma_extra_debug_sigprint();
  77. init(0, 0);
  78. }
  79. template<typename eT>
  80. inline
  81. void
  82. gmm_diag<eT>::reset(const uword in_n_dims, const uword in_n_gaus)
  83. {
  84. arma_extra_debug_sigprint();
  85. init(in_n_dims, in_n_gaus);
  86. }
  87. template<typename eT>
  88. template<typename T1, typename T2, typename T3>
  89. inline
  90. void
  91. gmm_diag<eT>::set_params(const Base<eT,T1>& in_means_expr, const Base<eT,T2>& in_dcovs_expr, const Base<eT,T3>& in_hefts_expr)
  92. {
  93. arma_extra_debug_sigprint();
  94. const unwrap<T1> tmp1(in_means_expr.get_ref());
  95. const unwrap<T2> tmp2(in_dcovs_expr.get_ref());
  96. const unwrap<T3> tmp3(in_hefts_expr.get_ref());
  97. const Mat<eT>& in_means = tmp1.M;
  98. const Mat<eT>& in_dcovs = tmp2.M;
  99. const Mat<eT>& in_hefts = tmp3.M;
  100. arma_debug_check
  101. (
  102. (arma::size(in_means) != arma::size(in_dcovs)) || (in_hefts.n_cols != in_means.n_cols) || (in_hefts.n_rows != 1),
  103. "gmm_diag::set_params(): given parameters have inconsistent and/or wrong sizes"
  104. );
  105. arma_debug_check( (in_means.is_finite() == false), "gmm_diag::set_params(): given means have non-finite values" );
  106. arma_debug_check( (in_dcovs.is_finite() == false), "gmm_diag::set_params(): given dcovs have non-finite values" );
  107. arma_debug_check( (in_hefts.is_finite() == false), "gmm_diag::set_params(): given hefts have non-finite values" );
  108. arma_debug_check( (any(vectorise(in_dcovs) <= eT(0))), "gmm_diag::set_params(): given dcovs have negative or zero values" );
  109. arma_debug_check( (any(vectorise(in_hefts) < eT(0))), "gmm_diag::set_params(): given hefts have negative values" );
  110. const eT s = accu(in_hefts);
  111. arma_debug_check( ((s < (eT(1) - eT(0.001))) || (s > (eT(1) + eT(0.001)))), "gmm_diag::set_params(): sum of given hefts is not 1" );
  112. access::rw(means) = in_means;
  113. access::rw(dcovs) = in_dcovs;
  114. access::rw(hefts) = in_hefts;
  115. init_constants();
  116. }
  117. template<typename eT>
  118. template<typename T1>
  119. inline
  120. void
  121. gmm_diag<eT>::set_means(const Base<eT,T1>& in_means_expr)
  122. {
  123. arma_extra_debug_sigprint();
  124. const unwrap<T1> tmp(in_means_expr.get_ref());
  125. const Mat<eT>& in_means = tmp.M;
  126. arma_debug_check( (arma::size(in_means) != arma::size(means)), "gmm_diag::set_means(): given means have incompatible size" );
  127. arma_debug_check( (in_means.is_finite() == false), "gmm_diag::set_means(): given means have non-finite values" );
  128. access::rw(means) = in_means;
  129. }
  130. template<typename eT>
  131. template<typename T1>
  132. inline
  133. void
  134. gmm_diag<eT>::set_dcovs(const Base<eT,T1>& in_dcovs_expr)
  135. {
  136. arma_extra_debug_sigprint();
  137. const unwrap<T1> tmp(in_dcovs_expr.get_ref());
  138. const Mat<eT>& in_dcovs = tmp.M;
  139. arma_debug_check( (arma::size(in_dcovs) != arma::size(dcovs)), "gmm_diag::set_dcovs(): given dcovs have incompatible size" );
  140. arma_debug_check( (in_dcovs.is_finite() == false), "gmm_diag::set_dcovs(): given dcovs have non-finite values" );
  141. arma_debug_check( (any(vectorise(in_dcovs) <= eT(0))), "gmm_diag::set_dcovs(): given dcovs have negative or zero values" );
  142. access::rw(dcovs) = in_dcovs;
  143. init_constants();
  144. }
  145. template<typename eT>
  146. template<typename T1>
  147. inline
  148. void
  149. gmm_diag<eT>::set_hefts(const Base<eT,T1>& in_hefts_expr)
  150. {
  151. arma_extra_debug_sigprint();
  152. const unwrap<T1> tmp(in_hefts_expr.get_ref());
  153. const Mat<eT>& in_hefts = tmp.M;
  154. arma_debug_check( (arma::size(in_hefts) != arma::size(hefts)), "gmm_diag::set_hefts(): given hefts have incompatible size" );
  155. arma_debug_check( (in_hefts.is_finite() == false), "gmm_diag::set_hefts(): given hefts have non-finite values" );
  156. arma_debug_check( (any(vectorise(in_hefts) < eT(0))), "gmm_diag::set_hefts(): given hefts have negative values" );
  157. const eT s = accu(in_hefts);
  158. arma_debug_check( ((s < (eT(1) - eT(0.001))) || (s > (eT(1) + eT(0.001)))), "gmm_diag::set_hefts(): sum of given hefts is not 1" );
  159. // make sure all hefts are positive and non-zero
  160. const eT* in_hefts_mem = in_hefts.memptr();
  161. eT* hefts_mem = access::rw(hefts).memptr();
  162. for(uword i=0; i < hefts.n_elem; ++i)
  163. {
  164. hefts_mem[i] = (std::max)( in_hefts_mem[i], std::numeric_limits<eT>::min() );
  165. }
  166. access::rw(hefts) /= accu(hefts);
  167. log_hefts = log(hefts);
  168. }
  169. template<typename eT>
  170. inline
  171. uword
  172. gmm_diag<eT>::n_dims() const
  173. {
  174. return means.n_rows;
  175. }
  176. template<typename eT>
  177. inline
  178. uword
  179. gmm_diag<eT>::n_gaus() const
  180. {
  181. return means.n_cols;
  182. }
  183. template<typename eT>
  184. inline
  185. bool
  186. gmm_diag<eT>::load(const std::string name)
  187. {
  188. arma_extra_debug_sigprint();
  189. Cube<eT> Q;
  190. bool status = Q.load(name, arma_binary);
  191. if( (status == false) || (Q.n_slices != 2) )
  192. {
  193. reset();
  194. arma_debug_warn("gmm_diag::load(): problem with loading or incompatible format");
  195. return false;
  196. }
  197. if( (Q.n_rows < 2) || (Q.n_cols < 1) )
  198. {
  199. reset();
  200. return true;
  201. }
  202. access::rw(hefts) = Q.slice(0).row(0);
  203. access::rw(means) = Q.slice(0).submat(1, 0, Q.n_rows-1, Q.n_cols-1);
  204. access::rw(dcovs) = Q.slice(1).submat(1, 0, Q.n_rows-1, Q.n_cols-1);
  205. init_constants();
  206. return true;
  207. }
  208. template<typename eT>
  209. inline
  210. bool
  211. gmm_diag<eT>::save(const std::string name) const
  212. {
  213. arma_extra_debug_sigprint();
  214. Cube<eT> Q(means.n_rows + 1, means.n_cols, 2);
  215. if(Q.n_elem > 0)
  216. {
  217. Q.slice(0).row(0) = hefts;
  218. Q.slice(1).row(0).zeros(); // reserved for future use
  219. Q.slice(0).submat(1, 0, arma::size(means)) = means;
  220. Q.slice(1).submat(1, 0, arma::size(dcovs)) = dcovs;
  221. }
  222. const bool status = Q.save(name, arma_binary);
  223. return status;
  224. }
  225. template<typename eT>
  226. inline
  227. Col<eT>
  228. gmm_diag<eT>::generate() const
  229. {
  230. arma_extra_debug_sigprint();
  231. const uword N_dims = means.n_rows;
  232. const uword N_gaus = means.n_cols;
  233. Col<eT> out( ((N_gaus > 0) ? N_dims : uword(0)), fill::randn );
  234. if(N_gaus > 0)
  235. {
  236. const double val = randu<double>();
  237. double csum = double(0);
  238. uword gaus_id = 0;
  239. for(uword j=0; j < N_gaus; ++j)
  240. {
  241. csum += hefts[j];
  242. if(val <= csum) { gaus_id = j; break; }
  243. }
  244. out %= sqrt(dcovs.col(gaus_id));
  245. out += means.col(gaus_id);
  246. }
  247. return out;
  248. }
  249. template<typename eT>
  250. inline
  251. Mat<eT>
  252. gmm_diag<eT>::generate(const uword N_vec) const
  253. {
  254. arma_extra_debug_sigprint();
  255. const uword N_dims = means.n_rows;
  256. const uword N_gaus = means.n_cols;
  257. Mat<eT> out( ( (N_gaus > 0) ? N_dims : uword(0) ), N_vec, fill::randn );
  258. if(N_gaus > 0)
  259. {
  260. const eT* hefts_mem = hefts.memptr();
  261. const Mat<eT> sqrt_dcovs = sqrt(dcovs);
  262. for(uword i=0; i < N_vec; ++i)
  263. {
  264. const double val = randu<double>();
  265. double csum = double(0);
  266. uword gaus_id = 0;
  267. for(uword j=0; j < N_gaus; ++j)
  268. {
  269. csum += hefts_mem[j];
  270. if(val <= csum) { gaus_id = j; break; }
  271. }
  272. subview_col<eT> out_col = out.col(i);
  273. out_col %= sqrt_dcovs.col(gaus_id);
  274. out_col += means.col(gaus_id);
  275. }
  276. }
  277. return out;
  278. }
  279. template<typename eT>
  280. template<typename T1>
  281. inline
  282. eT
  283. gmm_diag<eT>::log_p(const T1& expr, const gmm_empty_arg& junk1, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == true))>::result* junk2) const
  284. {
  285. arma_extra_debug_sigprint();
  286. arma_ignore(junk1);
  287. arma_ignore(junk2);
  288. const quasi_unwrap<T1> tmp(expr);
  289. arma_debug_check( (tmp.M.n_rows != means.n_rows), "gmm_diag::log_p(): incompatible dimensions" );
  290. return internal_scalar_log_p( tmp.M.memptr() );
  291. }
  292. template<typename eT>
  293. template<typename T1>
  294. inline
  295. eT
  296. gmm_diag<eT>::log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == true))>::result* junk2) const
  297. {
  298. arma_extra_debug_sigprint();
  299. arma_ignore(junk2);
  300. const quasi_unwrap<T1> tmp(expr);
  301. arma_debug_check( (tmp.M.n_rows != means.n_rows), "gmm_diag::log_p(): incompatible dimensions" );
  302. arma_debug_check( (gaus_id >= means.n_cols), "gmm_diag::log_p(): specified gaussian is out of range" );
  303. return internal_scalar_log_p( tmp.M.memptr(), gaus_id );
  304. }
  305. template<typename eT>
  306. template<typename T1>
  307. inline
  308. Row<eT>
  309. gmm_diag<eT>::log_p(const T1& expr, const gmm_empty_arg& junk1, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == false))>::result* junk2) const
  310. {
  311. arma_extra_debug_sigprint();
  312. arma_ignore(junk1);
  313. arma_ignore(junk2);
  314. const quasi_unwrap<T1> tmp(expr);
  315. const Mat<eT>& X = tmp.M;
  316. return internal_vec_log_p(X);
  317. }
  318. template<typename eT>
  319. template<typename T1>
  320. inline
  321. Row<eT>
  322. gmm_diag<eT>::log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == false))>::result* junk2) const
  323. {
  324. arma_extra_debug_sigprint();
  325. arma_ignore(junk2);
  326. const quasi_unwrap<T1> tmp(expr);
  327. const Mat<eT>& X = tmp.M;
  328. return internal_vec_log_p(X, gaus_id);
  329. }
  330. template<typename eT>
  331. template<typename T1>
  332. inline
  333. eT
  334. gmm_diag<eT>::sum_log_p(const Base<eT,T1>& expr) const
  335. {
  336. arma_extra_debug_sigprint();
  337. const quasi_unwrap<T1> tmp(expr.get_ref());
  338. const Mat<eT>& X = tmp.M;
  339. return internal_sum_log_p(X);
  340. }
  341. template<typename eT>
  342. template<typename T1>
  343. inline
  344. eT
  345. gmm_diag<eT>::sum_log_p(const Base<eT,T1>& expr, const uword gaus_id) const
  346. {
  347. arma_extra_debug_sigprint();
  348. const quasi_unwrap<T1> tmp(expr.get_ref());
  349. const Mat<eT>& X = tmp.M;
  350. return internal_sum_log_p(X, gaus_id);
  351. }
  352. template<typename eT>
  353. template<typename T1>
  354. inline
  355. eT
  356. gmm_diag<eT>::avg_log_p(const Base<eT,T1>& expr) const
  357. {
  358. arma_extra_debug_sigprint();
  359. const quasi_unwrap<T1> tmp(expr.get_ref());
  360. const Mat<eT>& X = tmp.M;
  361. return internal_avg_log_p(X);
  362. }
  363. template<typename eT>
  364. template<typename T1>
  365. inline
  366. eT
  367. gmm_diag<eT>::avg_log_p(const Base<eT,T1>& expr, const uword gaus_id) const
  368. {
  369. arma_extra_debug_sigprint();
  370. const quasi_unwrap<T1> tmp(expr.get_ref());
  371. const Mat<eT>& X = tmp.M;
  372. return internal_avg_log_p(X, gaus_id);
  373. }
  374. template<typename eT>
  375. template<typename T1>
  376. inline
  377. uword
  378. gmm_diag<eT>::assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == true))>::result* junk) const
  379. {
  380. arma_extra_debug_sigprint();
  381. arma_ignore(junk);
  382. const quasi_unwrap<T1> tmp(expr);
  383. const Mat<eT>& X = tmp.M;
  384. return internal_scalar_assign(X, dist);
  385. }
  386. template<typename eT>
  387. template<typename T1>
  388. inline
  389. urowvec
  390. gmm_diag<eT>::assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == false))>::result* junk) const
  391. {
  392. arma_extra_debug_sigprint();
  393. arma_ignore(junk);
  394. urowvec out;
  395. const quasi_unwrap<T1> tmp(expr);
  396. const Mat<eT>& X = tmp.M;
  397. internal_vec_assign(out, X, dist);
  398. return out;
  399. }
  400. template<typename eT>
  401. template<typename T1>
  402. inline
  403. urowvec
  404. gmm_diag<eT>::raw_hist(const Base<eT,T1>& expr, const gmm_dist_mode& dist_mode) const
  405. {
  406. arma_extra_debug_sigprint();
  407. const unwrap<T1> tmp(expr.get_ref());
  408. const Mat<eT>& X = tmp.M;
  409. arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::raw_hist(): incompatible dimensions" );
  410. arma_debug_check( ((dist_mode != eucl_dist) && (dist_mode != prob_dist)), "gmm_diag::raw_hist(): unsupported distance mode" );
  411. urowvec hist;
  412. internal_raw_hist(hist, X, dist_mode);
  413. return hist;
  414. }
  415. template<typename eT>
  416. template<typename T1>
  417. inline
  418. Row<eT>
  419. gmm_diag<eT>::norm_hist(const Base<eT,T1>& expr, const gmm_dist_mode& dist_mode) const
  420. {
  421. arma_extra_debug_sigprint();
  422. const unwrap<T1> tmp(expr.get_ref());
  423. const Mat<eT>& X = tmp.M;
  424. arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::norm_hist(): incompatible dimensions" );
  425. arma_debug_check( ((dist_mode != eucl_dist) && (dist_mode != prob_dist)), "gmm_diag::norm_hist(): unsupported distance mode" );
  426. urowvec hist;
  427. internal_raw_hist(hist, X, dist_mode);
  428. const uword hist_n_elem = hist.n_elem;
  429. const uword* hist_mem = hist.memptr();
  430. eT acc = eT(0);
  431. for(uword i=0; i<hist_n_elem; ++i) { acc += eT(hist_mem[i]); }
  432. if(acc == eT(0)) { acc = eT(1); }
  433. Row<eT> out(hist_n_elem);
  434. eT* out_mem = out.memptr();
  435. for(uword i=0; i<hist_n_elem; ++i) { out_mem[i] = eT(hist_mem[i]) / acc; }
  436. return out;
  437. }
  438. template<typename eT>
  439. template<typename T1>
  440. inline
  441. bool
  442. gmm_diag<eT>::learn
  443. (
  444. const Base<eT,T1>& data,
  445. const uword N_gaus,
  446. const gmm_dist_mode& dist_mode,
  447. const gmm_seed_mode& seed_mode,
  448. const uword km_iter,
  449. const uword em_iter,
  450. const eT var_floor,
  451. const bool print_mode
  452. )
  453. {
  454. arma_extra_debug_sigprint();
  455. const bool dist_mode_ok = (dist_mode == eucl_dist) || (dist_mode == maha_dist);
  456. const bool seed_mode_ok = \
  457. (seed_mode == keep_existing)
  458. || (seed_mode == static_subset)
  459. || (seed_mode == static_spread)
  460. || (seed_mode == random_subset)
  461. || (seed_mode == random_spread);
  462. arma_debug_check( (dist_mode_ok == false), "gmm_diag::learn(): dist_mode must be eucl_dist or maha_dist" );
  463. arma_debug_check( (seed_mode_ok == false), "gmm_diag::learn(): unknown seed_mode" );
  464. arma_debug_check( (var_floor < eT(0) ), "gmm_diag::learn(): variance floor is negative" );
  465. const unwrap<T1> tmp_X(data.get_ref());
  466. const Mat<eT>& X = tmp_X.M;
  467. if(X.is_empty() ) { arma_debug_warn("gmm_diag::learn(): given matrix is empty" ); return false; }
  468. if(X.is_finite() == false) { arma_debug_warn("gmm_diag::learn(): given matrix has non-finite values"); return false; }
  469. if(N_gaus == 0) { reset(); return true; }
  470. if(dist_mode == maha_dist)
  471. {
  472. mah_aux = var(X,1,1);
  473. const uword mah_aux_n_elem = mah_aux.n_elem;
  474. eT* mah_aux_mem = mah_aux.memptr();
  475. for(uword i=0; i < mah_aux_n_elem; ++i)
  476. {
  477. const eT val = mah_aux_mem[i];
  478. mah_aux_mem[i] = ((val != eT(0)) && arma_isfinite(val)) ? eT(1) / val : eT(1);
  479. }
  480. }
  481. // copy current model, in case of failure by k-means and/or EM
  482. const gmm_diag<eT> orig = (*this);
  483. // initial means
  484. if(seed_mode == keep_existing)
  485. {
  486. if(means.is_empty() ) { arma_debug_warn("gmm_diag::learn(): no existing means" ); return false; }
  487. if(X.n_rows != means.n_rows) { arma_debug_warn("gmm_diag::learn(): dimensionality mismatch"); return false; }
  488. // TODO: also check for number of vectors?
  489. }
  490. else
  491. {
  492. if(X.n_cols < N_gaus) { arma_debug_warn("gmm_diag::learn(): number of vectors is less than number of gaussians"); return false; }
  493. reset(X.n_rows, N_gaus);
  494. if(print_mode) { get_cout_stream() << "gmm_diag::learn(): generating initial means\n"; get_cout_stream().flush(); }
  495. if(dist_mode == eucl_dist) { generate_initial_means<1>(X, seed_mode); }
  496. else if(dist_mode == maha_dist) { generate_initial_means<2>(X, seed_mode); }
  497. }
  498. // k-means
  499. if(km_iter > 0)
  500. {
  501. const arma_ostream_state stream_state(get_cout_stream());
  502. bool status = false;
  503. if(dist_mode == eucl_dist) { status = km_iterate<1>(X, km_iter, print_mode, "gmm_diag::learn(): k-means"); }
  504. else if(dist_mode == maha_dist) { status = km_iterate<2>(X, km_iter, print_mode, "gmm_diag::learn(): k-means"); }
  505. stream_state.restore(get_cout_stream());
  506. if(status == false) { arma_debug_warn("gmm_diag::learn(): k-means algorithm failed; not enough data, or too many gaussians requested"); init(orig); return false; }
  507. }
  508. // initial dcovs
  509. const eT var_floor_actual = (eT(var_floor) > eT(0)) ? eT(var_floor) : std::numeric_limits<eT>::min();
  510. if(seed_mode != keep_existing)
  511. {
  512. if(print_mode) { get_cout_stream() << "gmm_diag::learn(): generating initial covariances\n"; get_cout_stream().flush(); }
  513. if(dist_mode == eucl_dist) { generate_initial_params<1>(X, var_floor_actual); }
  514. else if(dist_mode == maha_dist) { generate_initial_params<2>(X, var_floor_actual); }
  515. }
  516. // EM algorithm
  517. if(em_iter > 0)
  518. {
  519. const arma_ostream_state stream_state(get_cout_stream());
  520. const bool status = em_iterate(X, em_iter, var_floor_actual, print_mode);
  521. stream_state.restore(get_cout_stream());
  522. if(status == false) { arma_debug_warn("gmm_diag::learn(): EM algorithm failed"); init(orig); return false; }
  523. }
  524. mah_aux.reset();
  525. init_constants();
  526. return true;
  527. }
  528. template<typename eT>
  529. template<typename T1>
  530. inline
  531. bool
  532. gmm_diag<eT>::kmeans_wrapper
  533. (
  534. Mat<eT>& user_means,
  535. const Base<eT,T1>& data,
  536. const uword N_gaus,
  537. const gmm_seed_mode& seed_mode,
  538. const uword km_iter,
  539. const bool print_mode
  540. )
  541. {
  542. arma_extra_debug_sigprint();
  543. const bool seed_mode_ok = \
  544. (seed_mode == keep_existing)
  545. || (seed_mode == static_subset)
  546. || (seed_mode == static_spread)
  547. || (seed_mode == random_subset)
  548. || (seed_mode == random_spread);
  549. arma_debug_check( (seed_mode_ok == false), "kmeans(): unknown seed_mode" );
  550. const unwrap<T1> tmp_X(data.get_ref());
  551. const Mat<eT>& X = tmp_X.M;
  552. if(X.is_empty() ) { arma_debug_warn("kmeans(): given matrix is empty" ); return false; }
  553. if(X.is_finite() == false) { arma_debug_warn("kmeans(): given matrix has non-finite values"); return false; }
  554. if(N_gaus == 0) { reset(); return true; }
  555. // initial means
  556. if(seed_mode == keep_existing)
  557. {
  558. access::rw(means) = user_means;
  559. if(means.is_empty() ) { arma_debug_warn("kmeans(): no existing means" ); return false; }
  560. if(X.n_rows != means.n_rows) { arma_debug_warn("kmeans(): dimensionality mismatch"); return false; }
  561. // TODO: also check for number of vectors?
  562. }
  563. else
  564. {
  565. if(X.n_cols < N_gaus) { arma_debug_warn("kmeans(): number of vectors is less than number of means"); return false; }
  566. access::rw(means).zeros(X.n_rows, N_gaus);
  567. if(print_mode) { get_cout_stream() << "kmeans(): generating initial means\n"; }
  568. generate_initial_means<1>(X, seed_mode);
  569. }
  570. // k-means
  571. if(km_iter > 0)
  572. {
  573. const arma_ostream_state stream_state(get_cout_stream());
  574. bool status = false;
  575. status = km_iterate<1>(X, km_iter, print_mode, "kmeans()");
  576. stream_state.restore(get_cout_stream());
  577. if(status == false) { arma_debug_warn("kmeans(): clustering failed; not enough data, or too many means requested"); return false; }
  578. }
  579. return true;
  580. }
  581. //
  582. //
  583. //
  584. template<typename eT>
  585. inline
  586. void
  587. gmm_diag<eT>::init(const gmm_diag<eT>& x)
  588. {
  589. arma_extra_debug_sigprint();
  590. gmm_diag<eT>& t = *this;
  591. if(&t != &x)
  592. {
  593. access::rw(t.means) = x.means;
  594. access::rw(t.dcovs) = x.dcovs;
  595. access::rw(t.hefts) = x.hefts;
  596. init_constants();
  597. }
  598. }
  599. template<typename eT>
  600. inline
  601. void
  602. gmm_diag<eT>::init(const gmm_full<eT>& x)
  603. {
  604. arma_extra_debug_sigprint();
  605. access::rw(hefts) = x.hefts;
  606. access::rw(means) = x.means;
  607. const uword N_dims = x.means.n_rows;
  608. const uword N_gaus = x.means.n_cols;
  609. access::rw(dcovs).zeros(N_dims,N_gaus);
  610. for(uword g=0; g < N_gaus; ++g)
  611. {
  612. const Mat<eT>& fcov = x.fcovs.slice(g);
  613. eT* dcov_mem = access::rw(dcovs).colptr(g);
  614. for(uword d=0; d < N_dims; ++d)
  615. {
  616. dcov_mem[d] = fcov.at(d,d);
  617. }
  618. }
  619. init_constants();
  620. }
  621. template<typename eT>
  622. inline
  623. void
  624. gmm_diag<eT>::init(const uword in_n_dims, const uword in_n_gaus)
  625. {
  626. arma_extra_debug_sigprint();
  627. access::rw(means).zeros(in_n_dims, in_n_gaus);
  628. access::rw(dcovs).ones(in_n_dims, in_n_gaus);
  629. access::rw(hefts).set_size(in_n_gaus);
  630. access::rw(hefts).fill(eT(1) / eT(in_n_gaus));
  631. init_constants();
  632. }
  633. template<typename eT>
  634. inline
  635. void
  636. gmm_diag<eT>::init_constants()
  637. {
  638. arma_extra_debug_sigprint();
  639. const uword N_dims = means.n_rows;
  640. const uword N_gaus = means.n_cols;
  641. //
  642. inv_dcovs.copy_size(dcovs);
  643. const eT* dcovs_mem = dcovs.memptr();
  644. eT* inv_dcovs_mem = inv_dcovs.memptr();
  645. const uword dcovs_n_elem = dcovs.n_elem;
  646. for(uword i=0; i < dcovs_n_elem; ++i)
  647. {
  648. inv_dcovs_mem[i] = eT(1) / (std::max)( dcovs_mem[i], std::numeric_limits<eT>::min() );
  649. }
  650. //
  651. const eT tmp = (eT(N_dims)/eT(2)) * std::log(eT(2) * Datum<eT>::pi);
  652. log_det_etc.set_size(N_gaus);
  653. for(uword g=0; g < N_gaus; ++g)
  654. {
  655. const eT* dcovs_colmem = dcovs.colptr(g);
  656. eT log_det_val = eT(0);
  657. for(uword d=0; d < N_dims; ++d)
  658. {
  659. log_det_val += std::log( (std::max)( dcovs_colmem[d], std::numeric_limits<eT>::min() ) );
  660. }
  661. log_det_etc[g] = eT(-1) * ( tmp + eT(0.5) * log_det_val );
  662. }
  663. //
  664. eT* hefts_mem = access::rw(hefts).memptr();
  665. for(uword g=0; g < N_gaus; ++g)
  666. {
  667. hefts_mem[g] = (std::max)( hefts_mem[g], std::numeric_limits<eT>::min() );
  668. }
  669. log_hefts = log(hefts);
  670. }
  671. template<typename eT>
  672. inline
  673. umat
  674. gmm_diag<eT>::internal_gen_boundaries(const uword N) const
  675. {
  676. arma_extra_debug_sigprint();
  677. #if defined(ARMA_USE_OPENMP)
  678. const uword n_threads_avail = (omp_in_parallel()) ? uword(1) : uword(omp_get_max_threads());
  679. const uword n_threads = (n_threads_avail > 0) ? ( (n_threads_avail <= N) ? n_threads_avail : 1 ) : 1;
  680. #else
  681. static const uword n_threads = 1;
  682. #endif
  683. // get_cout_stream() << "gmm_diag::internal_gen_boundaries(): n_threads: " << n_threads << '\n';
  684. umat boundaries(2, n_threads);
  685. if(N > 0)
  686. {
  687. const uword chunk_size = N / n_threads;
  688. uword count = 0;
  689. for(uword t=0; t<n_threads; t++)
  690. {
  691. boundaries.at(0,t) = count;
  692. count += chunk_size;
  693. boundaries.at(1,t) = count-1;
  694. }
  695. boundaries.at(1,n_threads-1) = N - 1;
  696. }
  697. else
  698. {
  699. boundaries.zeros();
  700. }
  701. // get_cout_stream() << "gmm_diag::internal_gen_boundaries(): boundaries: " << '\n' << boundaries << '\n';
  702. return boundaries;
  703. }
  704. template<typename eT>
  705. arma_hot
  706. inline
  707. eT
  708. gmm_diag<eT>::internal_scalar_log_p(const eT* x) const
  709. {
  710. arma_extra_debug_sigprint();
  711. const eT* log_hefts_mem = log_hefts.mem;
  712. const uword N_gaus = means.n_cols;
  713. if(N_gaus > 0)
  714. {
  715. eT log_sum = internal_scalar_log_p(x, 0) + log_hefts_mem[0];
  716. for(uword g=1; g < N_gaus; ++g)
  717. {
  718. const eT tmp = internal_scalar_log_p(x, g) + log_hefts_mem[g];
  719. log_sum = log_add_exp(log_sum, tmp);
  720. }
  721. return log_sum;
  722. }
  723. else
  724. {
  725. return -Datum<eT>::inf;
  726. }
  727. }
  728. template<typename eT>
  729. arma_hot
  730. inline
  731. eT
  732. gmm_diag<eT>::internal_scalar_log_p(const eT* x, const uword g) const
  733. {
  734. arma_extra_debug_sigprint();
  735. const eT* mean = means.colptr(g);
  736. const eT* inv_dcov = inv_dcovs.colptr(g);
  737. const uword N_dims = means.n_rows;
  738. eT val_i = eT(0);
  739. eT val_j = eT(0);
  740. uword i,j;
  741. for(i=0, j=1; j<N_dims; i+=2, j+=2)
  742. {
  743. eT tmp_i = x[i];
  744. eT tmp_j = x[j];
  745. tmp_i -= mean[i];
  746. tmp_j -= mean[j];
  747. val_i += (tmp_i*tmp_i) * inv_dcov[i];
  748. val_j += (tmp_j*tmp_j) * inv_dcov[j];
  749. }
  750. if(i < N_dims)
  751. {
  752. const eT tmp = x[i] - mean[i];
  753. val_i += (tmp*tmp) * inv_dcov[i];
  754. }
  755. return eT(-0.5)*(val_i + val_j) + log_det_etc.mem[g];
  756. }
  757. template<typename eT>
  758. inline
  759. Row<eT>
  760. gmm_diag<eT>::internal_vec_log_p(const Mat<eT>& X) const
  761. {
  762. arma_extra_debug_sigprint();
  763. arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::log_p(): incompatible dimensions" );
  764. const uword N = X.n_cols;
  765. Row<eT> out(N);
  766. if(N > 0)
  767. {
  768. #if defined(ARMA_USE_OPENMP)
  769. {
  770. const umat boundaries = internal_gen_boundaries(N);
  771. const uword n_threads = boundaries.n_cols;
  772. #pragma omp parallel for schedule(static)
  773. for(uword t=0; t < n_threads; ++t)
  774. {
  775. const uword start_index = boundaries.at(0,t);
  776. const uword end_index = boundaries.at(1,t);
  777. eT* out_mem = out.memptr();
  778. for(uword i=start_index; i <= end_index; ++i)
  779. {
  780. out_mem[i] = internal_scalar_log_p( X.colptr(i) );
  781. }
  782. }
  783. }
  784. #else
  785. {
  786. eT* out_mem = out.memptr();
  787. for(uword i=0; i < N; ++i)
  788. {
  789. out_mem[i] = internal_scalar_log_p( X.colptr(i) );
  790. }
  791. }
  792. #endif
  793. }
  794. return out;
  795. }
  796. template<typename eT>
  797. inline
  798. Row<eT>
  799. gmm_diag<eT>::internal_vec_log_p(const Mat<eT>& X, const uword gaus_id) const
  800. {
  801. arma_extra_debug_sigprint();
  802. arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::log_p(): incompatible dimensions" );
  803. arma_debug_check( (gaus_id >= means.n_cols), "gmm_diag::log_p(): specified gaussian is out of range" );
  804. const uword N = X.n_cols;
  805. Row<eT> out(N);
  806. if(N > 0)
  807. {
  808. #if defined(ARMA_USE_OPENMP)
  809. {
  810. const umat boundaries = internal_gen_boundaries(N);
  811. const uword n_threads = boundaries.n_cols;
  812. #pragma omp parallel for schedule(static)
  813. for(uword t=0; t < n_threads; ++t)
  814. {
  815. const uword start_index = boundaries.at(0,t);
  816. const uword end_index = boundaries.at(1,t);
  817. eT* out_mem = out.memptr();
  818. for(uword i=start_index; i <= end_index; ++i)
  819. {
  820. out_mem[i] = internal_scalar_log_p( X.colptr(i), gaus_id );
  821. }
  822. }
  823. }
  824. #else
  825. {
  826. eT* out_mem = out.memptr();
  827. for(uword i=0; i < N; ++i)
  828. {
  829. out_mem[i] = internal_scalar_log_p( X.colptr(i), gaus_id );
  830. }
  831. }
  832. #endif
  833. }
  834. return out;
  835. }
  836. template<typename eT>
  837. inline
  838. eT
  839. gmm_diag<eT>::internal_sum_log_p(const Mat<eT>& X) const
  840. {
  841. arma_extra_debug_sigprint();
  842. arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::sum_log_p(): incompatible dimensions" );
  843. const uword N = X.n_cols;
  844. if(N == 0) { return (-Datum<eT>::inf); }
  845. #if defined(ARMA_USE_OPENMP)
  846. {
  847. const umat boundaries = internal_gen_boundaries(N);
  848. const uword n_threads = boundaries.n_cols;
  849. Col<eT> t_accs(n_threads, fill::zeros);
  850. #pragma omp parallel for schedule(static)
  851. for(uword t=0; t < n_threads; ++t)
  852. {
  853. const uword start_index = boundaries.at(0,t);
  854. const uword end_index = boundaries.at(1,t);
  855. eT t_acc = eT(0);
  856. for(uword i=start_index; i <= end_index; ++i)
  857. {
  858. t_acc += internal_scalar_log_p( X.colptr(i) );
  859. }
  860. t_accs[t] = t_acc;
  861. }
  862. return eT(accu(t_accs));
  863. }
  864. #else
  865. {
  866. eT acc = eT(0);
  867. for(uword i=0; i<N; ++i)
  868. {
  869. acc += internal_scalar_log_p( X.colptr(i) );
  870. }
  871. return acc;
  872. }
  873. #endif
  874. }
  875. template<typename eT>
  876. inline
  877. eT
  878. gmm_diag<eT>::internal_sum_log_p(const Mat<eT>& X, const uword gaus_id) const
  879. {
  880. arma_extra_debug_sigprint();
  881. arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::sum_log_p(): incompatible dimensions" );
  882. arma_debug_check( (gaus_id >= means.n_cols), "gmm_diag::sum_log_p(): specified gaussian is out of range" );
  883. const uword N = X.n_cols;
  884. if(N == 0) { return (-Datum<eT>::inf); }
  885. #if defined(ARMA_USE_OPENMP)
  886. {
  887. const umat boundaries = internal_gen_boundaries(N);
  888. const uword n_threads = boundaries.n_cols;
  889. Col<eT> t_accs(n_threads, fill::zeros);
  890. #pragma omp parallel for schedule(static)
  891. for(uword t=0; t < n_threads; ++t)
  892. {
  893. const uword start_index = boundaries.at(0,t);
  894. const uword end_index = boundaries.at(1,t);
  895. eT t_acc = eT(0);
  896. for(uword i=start_index; i <= end_index; ++i)
  897. {
  898. t_acc += internal_scalar_log_p( X.colptr(i), gaus_id );
  899. }
  900. t_accs[t] = t_acc;
  901. }
  902. return eT(accu(t_accs));
  903. }
  904. #else
  905. {
  906. eT acc = eT(0);
  907. for(uword i=0; i<N; ++i)
  908. {
  909. acc += internal_scalar_log_p( X.colptr(i), gaus_id );
  910. }
  911. return acc;
  912. }
  913. #endif
  914. }
  915. template<typename eT>
  916. inline
  917. eT
  918. gmm_diag<eT>::internal_avg_log_p(const Mat<eT>& X) const
  919. {
  920. arma_extra_debug_sigprint();
  921. arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::avg_log_p(): incompatible dimensions" );
  922. const uword N = X.n_cols;
  923. if(N == 0) { return (-Datum<eT>::inf); }
  924. #if defined(ARMA_USE_OPENMP)
  925. {
  926. const umat boundaries = internal_gen_boundaries(N);
  927. const uword n_threads = boundaries.n_cols;
  928. field< running_mean_scalar<eT> > t_running_means(n_threads);
  929. #pragma omp parallel for schedule(static)
  930. for(uword t=0; t < n_threads; ++t)
  931. {
  932. const uword start_index = boundaries.at(0,t);
  933. const uword end_index = boundaries.at(1,t);
  934. running_mean_scalar<eT>& current_running_mean = t_running_means[t];
  935. for(uword i=start_index; i <= end_index; ++i)
  936. {
  937. current_running_mean( internal_scalar_log_p( X.colptr(i) ) );
  938. }
  939. }
  940. eT avg = eT(0);
  941. for(uword t=0; t < n_threads; ++t)
  942. {
  943. running_mean_scalar<eT>& current_running_mean = t_running_means[t];
  944. const eT w = eT(current_running_mean.count()) / eT(N);
  945. avg += w * current_running_mean.mean();
  946. }
  947. return avg;
  948. }
  949. #else
  950. {
  951. running_mean_scalar<eT> running_mean;
  952. for(uword i=0; i<N; ++i)
  953. {
  954. running_mean( internal_scalar_log_p( X.colptr(i) ) );
  955. }
  956. return running_mean.mean();
  957. }
  958. #endif
  959. }
  960. template<typename eT>
  961. inline
  962. eT
  963. gmm_diag<eT>::internal_avg_log_p(const Mat<eT>& X, const uword gaus_id) const
  964. {
  965. arma_extra_debug_sigprint();
  966. arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::avg_log_p(): incompatible dimensions" );
  967. arma_debug_check( (gaus_id >= means.n_cols), "gmm_diag::avg_log_p(): specified gaussian is out of range" );
  968. const uword N = X.n_cols;
  969. if(N == 0) { return (-Datum<eT>::inf); }
  970. #if defined(ARMA_USE_OPENMP)
  971. {
  972. const umat boundaries = internal_gen_boundaries(N);
  973. const uword n_threads = boundaries.n_cols;
  974. field< running_mean_scalar<eT> > t_running_means(n_threads);
  975. #pragma omp parallel for schedule(static)
  976. for(uword t=0; t < n_threads; ++t)
  977. {
  978. const uword start_index = boundaries.at(0,t);
  979. const uword end_index = boundaries.at(1,t);
  980. running_mean_scalar<eT>& current_running_mean = t_running_means[t];
  981. for(uword i=start_index; i <= end_index; ++i)
  982. {
  983. current_running_mean( internal_scalar_log_p( X.colptr(i), gaus_id) );
  984. }
  985. }
  986. eT avg = eT(0);
  987. for(uword t=0; t < n_threads; ++t)
  988. {
  989. running_mean_scalar<eT>& current_running_mean = t_running_means[t];
  990. const eT w = eT(current_running_mean.count()) / eT(N);
  991. avg += w * current_running_mean.mean();
  992. }
  993. return avg;
  994. }
  995. #else
  996. {
  997. running_mean_scalar<eT> running_mean;
  998. for(uword i=0; i<N; ++i)
  999. {
  1000. running_mean( internal_scalar_log_p( X.colptr(i), gaus_id ) );
  1001. }
  1002. return running_mean.mean();
  1003. }
  1004. #endif
  1005. }
  1006. template<typename eT>
  1007. inline
  1008. uword
  1009. gmm_diag<eT>::internal_scalar_assign(const Mat<eT>& X, const gmm_dist_mode& dist_mode) const
  1010. {
  1011. arma_extra_debug_sigprint();
  1012. const uword N_dims = means.n_rows;
  1013. const uword N_gaus = means.n_cols;
  1014. arma_debug_check( (X.n_rows != N_dims), "gmm_diag::assign(): incompatible dimensions" );
  1015. arma_debug_check( (N_gaus == 0), "gmm_diag::assign(): model has no means" );
  1016. const eT* X_mem = X.colptr(0);
  1017. if(dist_mode == eucl_dist)
  1018. {
  1019. eT best_dist = Datum<eT>::inf;
  1020. uword best_g = 0;
  1021. for(uword g=0; g < N_gaus; ++g)
  1022. {
  1023. const eT tmp_dist = distance<eT,1>::eval(N_dims, X_mem, means.colptr(g), X_mem);
  1024. if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; }
  1025. }
  1026. return best_g;
  1027. }
  1028. else
  1029. if(dist_mode == prob_dist)
  1030. {
  1031. const eT* log_hefts_mem = log_hefts.memptr();
  1032. eT best_p = -Datum<eT>::inf;
  1033. uword best_g = 0;
  1034. for(uword g=0; g < N_gaus; ++g)
  1035. {
  1036. const eT tmp_p = internal_scalar_log_p(X_mem, g) + log_hefts_mem[g];
  1037. if(tmp_p >= best_p) { best_p = tmp_p; best_g = g; }
  1038. }
  1039. return best_g;
  1040. }
  1041. else
  1042. {
  1043. arma_debug_check(true, "gmm_diag::assign(): unsupported distance mode");
  1044. }
  1045. return uword(0);
  1046. }
  1047. template<typename eT>
  1048. inline
  1049. void
  1050. gmm_diag<eT>::internal_vec_assign(urowvec& out, const Mat<eT>& X, const gmm_dist_mode& dist_mode) const
  1051. {
  1052. arma_extra_debug_sigprint();
  1053. const uword N_dims = means.n_rows;
  1054. const uword N_gaus = means.n_cols;
  1055. arma_debug_check( (X.n_rows != N_dims), "gmm_diag::assign(): incompatible dimensions" );
  1056. const uword X_n_cols = (N_gaus > 0) ? X.n_cols : 0;
  1057. out.set_size(1,X_n_cols);
  1058. uword* out_mem = out.memptr();
  1059. if(dist_mode == eucl_dist)
  1060. {
  1061. #if defined(ARMA_USE_OPENMP)
  1062. {
  1063. #pragma omp parallel for schedule(static)
  1064. for(uword i=0; i<X_n_cols; ++i)
  1065. {
  1066. const eT* X_colptr = X.colptr(i);
  1067. eT best_dist = Datum<eT>::inf;
  1068. uword best_g = 0;
  1069. for(uword g=0; g<N_gaus; ++g)
  1070. {
  1071. const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.colptr(g), X_colptr);
  1072. if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; }
  1073. }
  1074. out_mem[i] = best_g;
  1075. }
  1076. }
  1077. #else
  1078. {
  1079. for(uword i=0; i<X_n_cols; ++i)
  1080. {
  1081. const eT* X_colptr = X.colptr(i);
  1082. eT best_dist = Datum<eT>::inf;
  1083. uword best_g = 0;
  1084. for(uword g=0; g<N_gaus; ++g)
  1085. {
  1086. const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.colptr(g), X_colptr);
  1087. if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; }
  1088. }
  1089. out_mem[i] = best_g;
  1090. }
  1091. }
  1092. #endif
  1093. }
  1094. else
  1095. if(dist_mode == prob_dist)
  1096. {
  1097. #if defined(ARMA_USE_OPENMP)
  1098. {
  1099. const eT* log_hefts_mem = log_hefts.memptr();
  1100. #pragma omp parallel for schedule(static)
  1101. for(uword i=0; i<X_n_cols; ++i)
  1102. {
  1103. const eT* X_colptr = X.colptr(i);
  1104. eT best_p = -Datum<eT>::inf;
  1105. uword best_g = 0;
  1106. for(uword g=0; g<N_gaus; ++g)
  1107. {
  1108. const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g];
  1109. if(tmp_p >= best_p) { best_p = tmp_p; best_g = g; }
  1110. }
  1111. out_mem[i] = best_g;
  1112. }
  1113. }
  1114. #else
  1115. {
  1116. const eT* log_hefts_mem = log_hefts.memptr();
  1117. for(uword i=0; i<X_n_cols; ++i)
  1118. {
  1119. const eT* X_colptr = X.colptr(i);
  1120. eT best_p = -Datum<eT>::inf;
  1121. uword best_g = 0;
  1122. for(uword g=0; g<N_gaus; ++g)
  1123. {
  1124. const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g];
  1125. if(tmp_p >= best_p) { best_p = tmp_p; best_g = g; }
  1126. }
  1127. out_mem[i] = best_g;
  1128. }
  1129. }
  1130. #endif
  1131. }
  1132. else
  1133. {
  1134. arma_debug_check(true, "gmm_diag::assign(): unsupported distance mode");
  1135. }
  1136. }
  1137. template<typename eT>
  1138. inline
  1139. void
  1140. gmm_diag<eT>::internal_raw_hist(urowvec& hist, const Mat<eT>& X, const gmm_dist_mode& dist_mode) const
  1141. {
  1142. arma_extra_debug_sigprint();
  1143. const uword N_dims = means.n_rows;
  1144. const uword N_gaus = means.n_cols;
  1145. const uword X_n_cols = X.n_cols;
  1146. hist.zeros(N_gaus);
  1147. if(N_gaus == 0) { return; }
  1148. #if defined(ARMA_USE_OPENMP)
  1149. {
  1150. const umat boundaries = internal_gen_boundaries(X_n_cols);
  1151. const uword n_threads = boundaries.n_cols;
  1152. field<urowvec> thread_hist(n_threads);
  1153. for(uword t=0; t < n_threads; ++t) { thread_hist(t).zeros(N_gaus); }
  1154. if(dist_mode == eucl_dist)
  1155. {
  1156. #pragma omp parallel for schedule(static)
  1157. for(uword t=0; t < n_threads; ++t)
  1158. {
  1159. uword* thread_hist_mem = thread_hist(t).memptr();
  1160. const uword start_index = boundaries.at(0,t);
  1161. const uword end_index = boundaries.at(1,t);
  1162. for(uword i=start_index; i <= end_index; ++i)
  1163. {
  1164. const eT* X_colptr = X.colptr(i);
  1165. eT best_dist = Datum<eT>::inf;
  1166. uword best_g = 0;
  1167. for(uword g=0; g < N_gaus; ++g)
  1168. {
  1169. const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.colptr(g), X_colptr);
  1170. if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; }
  1171. }
  1172. thread_hist_mem[best_g]++;
  1173. }
  1174. }
  1175. }
  1176. else
  1177. if(dist_mode == prob_dist)
  1178. {
  1179. const eT* log_hefts_mem = log_hefts.memptr();
  1180. #pragma omp parallel for schedule(static)
  1181. for(uword t=0; t < n_threads; ++t)
  1182. {
  1183. uword* thread_hist_mem = thread_hist(t).memptr();
  1184. const uword start_index = boundaries.at(0,t);
  1185. const uword end_index = boundaries.at(1,t);
  1186. for(uword i=start_index; i <= end_index; ++i)
  1187. {
  1188. const eT* X_colptr = X.colptr(i);
  1189. eT best_p = -Datum<eT>::inf;
  1190. uword best_g = 0;
  1191. for(uword g=0; g < N_gaus; ++g)
  1192. {
  1193. const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g];
  1194. if(tmp_p >= best_p) { best_p = tmp_p; best_g = g; }
  1195. }
  1196. thread_hist_mem[best_g]++;
  1197. }
  1198. }
  1199. }
  1200. // reduction
  1201. hist = thread_hist(0);
  1202. for(uword t=1; t < n_threads; ++t)
  1203. {
  1204. hist += thread_hist(t);
  1205. }
  1206. }
  1207. #else
  1208. {
  1209. uword* hist_mem = hist.memptr();
  1210. if(dist_mode == eucl_dist)
  1211. {
  1212. for(uword i=0; i<X_n_cols; ++i)
  1213. {
  1214. const eT* X_colptr = X.colptr(i);
  1215. eT best_dist = Datum<eT>::inf;
  1216. uword best_g = 0;
  1217. for(uword g=0; g < N_gaus; ++g)
  1218. {
  1219. const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.colptr(g), X_colptr);
  1220. if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; }
  1221. }
  1222. hist_mem[best_g]++;
  1223. }
  1224. }
  1225. else
  1226. if(dist_mode == prob_dist)
  1227. {
  1228. const eT* log_hefts_mem = log_hefts.memptr();
  1229. for(uword i=0; i<X_n_cols; ++i)
  1230. {
  1231. const eT* X_colptr = X.colptr(i);
  1232. eT best_p = -Datum<eT>::inf;
  1233. uword best_g = 0;
  1234. for(uword g=0; g < N_gaus; ++g)
  1235. {
  1236. const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g];
  1237. if(tmp_p >= best_p) { best_p = tmp_p; best_g = g; }
  1238. }
  1239. hist_mem[best_g]++;
  1240. }
  1241. }
  1242. }
  1243. #endif
  1244. }
  1245. template<typename eT>
  1246. template<uword dist_id>
  1247. inline
  1248. void
  1249. gmm_diag<eT>::generate_initial_means(const Mat<eT>& X, const gmm_seed_mode& seed_mode)
  1250. {
  1251. arma_extra_debug_sigprint();
  1252. const uword N_dims = means.n_rows;
  1253. const uword N_gaus = means.n_cols;
  1254. if( (seed_mode == static_subset) || (seed_mode == random_subset) )
  1255. {
  1256. uvec initial_indices;
  1257. if(seed_mode == static_subset) { initial_indices = linspace<uvec>(0, X.n_cols-1, N_gaus); }
  1258. else if(seed_mode == random_subset) { initial_indices = randperm<uvec>(X.n_cols, N_gaus); }
  1259. // initial_indices.print("initial_indices:");
  1260. access::rw(means) = X.cols(initial_indices);
  1261. }
  1262. else
  1263. if( (seed_mode == static_spread) || (seed_mode == random_spread) )
  1264. {
  1265. // going through all of the samples can be extremely time consuming;
  1266. // instead, if there are enough samples, randomly choose samples with probability 0.1
  1267. const bool use_sampling = ((X.n_cols/uword(100)) > N_gaus);
  1268. const uword step = (use_sampling) ? uword(10) : uword(1);
  1269. uword start_index = 0;
  1270. if(seed_mode == static_spread) { start_index = X.n_cols / 2; }
  1271. else if(seed_mode == random_spread) { start_index = as_scalar(randi<uvec>(1, distr_param(0,X.n_cols-1))); }
  1272. access::rw(means).col(0) = X.unsafe_col(start_index);
  1273. const eT* mah_aux_mem = mah_aux.memptr();
  1274. running_stat<double> rs;
  1275. for(uword g=1; g < N_gaus; ++g)
  1276. {
  1277. eT max_dist = eT(0);
  1278. uword best_i = uword(0);
  1279. uword start_i = uword(0);
  1280. if(use_sampling)
  1281. {
  1282. uword start_i_proposed = uword(0);
  1283. if(seed_mode == static_spread) { start_i_proposed = g % uword(10); }
  1284. if(seed_mode == random_spread) { start_i_proposed = as_scalar(randi<uvec>(1, distr_param(0,9))); }
  1285. if(start_i_proposed < X.n_cols) { start_i = start_i_proposed; }
  1286. }
  1287. for(uword i=start_i; i < X.n_cols; i += step)
  1288. {
  1289. rs.reset();
  1290. const eT* X_colptr = X.colptr(i);
  1291. bool ignore_i = false;
  1292. // find the average distance between sample i and the means so far
  1293. for(uword h = 0; h < g; ++h)
  1294. {
  1295. const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, means.colptr(h), mah_aux_mem);
  1296. // ignore sample already selected as a mean
  1297. if(dist == eT(0)) { ignore_i = true; break; }
  1298. else { rs(dist); }
  1299. }
  1300. if( (rs.mean() >= max_dist) && (ignore_i == false))
  1301. {
  1302. max_dist = eT(rs.mean()); best_i = i;
  1303. }
  1304. }
  1305. // set the mean to the sample that is the furthest away from the means so far
  1306. access::rw(means).col(g) = X.unsafe_col(best_i);
  1307. }
  1308. }
  1309. // get_cout_stream() << "generate_initial_means():" << '\n';
  1310. // means.print();
  1311. }
  1312. template<typename eT>
  1313. template<uword dist_id>
  1314. inline
  1315. void
  1316. gmm_diag<eT>::generate_initial_params(const Mat<eT>& X, const eT var_floor)
  1317. {
  1318. arma_extra_debug_sigprint();
  1319. const uword N_dims = means.n_rows;
  1320. const uword N_gaus = means.n_cols;
  1321. const eT* mah_aux_mem = mah_aux.memptr();
  1322. const uword X_n_cols = X.n_cols;
  1323. if(X_n_cols == 0) { return; }
  1324. // as the covariances are calculated via accumulators,
  1325. // the means also need to be calculated via accumulators to ensure numerical consistency
  1326. Mat<eT> acc_means(N_dims, N_gaus, fill::zeros);
  1327. Mat<eT> acc_dcovs(N_dims, N_gaus, fill::zeros);
  1328. Row<uword> acc_hefts(N_gaus, fill::zeros);
  1329. uword* acc_hefts_mem = acc_hefts.memptr();
  1330. #if defined(ARMA_USE_OPENMP)
  1331. {
  1332. const umat boundaries = internal_gen_boundaries(X_n_cols);
  1333. const uword n_threads = boundaries.n_cols;
  1334. field< Mat<eT> > t_acc_means(n_threads);
  1335. field< Mat<eT> > t_acc_dcovs(n_threads);
  1336. field< Row<uword> > t_acc_hefts(n_threads);
  1337. for(uword t=0; t < n_threads; ++t)
  1338. {
  1339. t_acc_means(t).zeros(N_dims, N_gaus);
  1340. t_acc_dcovs(t).zeros(N_dims, N_gaus);
  1341. t_acc_hefts(t).zeros(N_gaus);
  1342. }
  1343. #pragma omp parallel for schedule(static)
  1344. for(uword t=0; t < n_threads; ++t)
  1345. {
  1346. uword* t_acc_hefts_mem = t_acc_hefts(t).memptr();
  1347. const uword start_index = boundaries.at(0,t);
  1348. const uword end_index = boundaries.at(1,t);
  1349. for(uword i=start_index; i <= end_index; ++i)
  1350. {
  1351. const eT* X_colptr = X.colptr(i);
  1352. eT min_dist = Datum<eT>::inf;
  1353. uword best_g = 0;
  1354. for(uword g=0; g<N_gaus; ++g)
  1355. {
  1356. const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, means.colptr(g), mah_aux_mem);
  1357. if(dist < min_dist) { min_dist = dist; best_g = g; }
  1358. }
  1359. eT* t_acc_mean = t_acc_means(t).colptr(best_g);
  1360. eT* t_acc_dcov = t_acc_dcovs(t).colptr(best_g);
  1361. for(uword d=0; d<N_dims; ++d)
  1362. {
  1363. const eT x_d = X_colptr[d];
  1364. t_acc_mean[d] += x_d;
  1365. t_acc_dcov[d] += x_d*x_d;
  1366. }
  1367. t_acc_hefts_mem[best_g]++;
  1368. }
  1369. }
  1370. // reduction
  1371. acc_means = t_acc_means(0);
  1372. acc_dcovs = t_acc_dcovs(0);
  1373. acc_hefts = t_acc_hefts(0);
  1374. for(uword t=1; t < n_threads; ++t)
  1375. {
  1376. acc_means += t_acc_means(t);
  1377. acc_dcovs += t_acc_dcovs(t);
  1378. acc_hefts += t_acc_hefts(t);
  1379. }
  1380. }
  1381. #else
  1382. {
  1383. for(uword i=0; i<X_n_cols; ++i)
  1384. {
  1385. const eT* X_colptr = X.colptr(i);
  1386. eT min_dist = Datum<eT>::inf;
  1387. uword best_g = 0;
  1388. for(uword g=0; g<N_gaus; ++g)
  1389. {
  1390. const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, means.colptr(g), mah_aux_mem);
  1391. if(dist < min_dist) { min_dist = dist; best_g = g; }
  1392. }
  1393. eT* acc_mean = acc_means.colptr(best_g);
  1394. eT* acc_dcov = acc_dcovs.colptr(best_g);
  1395. for(uword d=0; d<N_dims; ++d)
  1396. {
  1397. const eT x_d = X_colptr[d];
  1398. acc_mean[d] += x_d;
  1399. acc_dcov[d] += x_d*x_d;
  1400. }
  1401. acc_hefts_mem[best_g]++;
  1402. }
  1403. }
  1404. #endif
  1405. eT* hefts_mem = access::rw(hefts).memptr();
  1406. for(uword g=0; g<N_gaus; ++g)
  1407. {
  1408. const eT* acc_mean = acc_means.colptr(g);
  1409. const eT* acc_dcov = acc_dcovs.colptr(g);
  1410. const uword acc_heft = acc_hefts_mem[g];
  1411. eT* mean = access::rw(means).colptr(g);
  1412. eT* dcov = access::rw(dcovs).colptr(g);
  1413. for(uword d=0; d<N_dims; ++d)
  1414. {
  1415. const eT tmp = acc_mean[d] / eT(acc_heft);
  1416. mean[d] = (acc_heft >= 1) ? tmp : eT(0);
  1417. dcov[d] = (acc_heft >= 2) ? eT((acc_dcov[d] / eT(acc_heft)) - (tmp*tmp)) : eT(var_floor);
  1418. }
  1419. hefts_mem[g] = eT(acc_heft) / eT(X_n_cols);
  1420. }
  1421. em_fix_params(var_floor);
  1422. }
  1423. //! multi-threaded implementation of k-means, inspired by MapReduce
  1424. template<typename eT>
  1425. template<uword dist_id>
  1426. inline
  1427. bool
  1428. gmm_diag<eT>::km_iterate(const Mat<eT>& X, const uword max_iter, const bool verbose, const char* signature)
  1429. {
  1430. arma_extra_debug_sigprint();
  1431. if(verbose)
  1432. {
  1433. get_cout_stream().unsetf(ios::showbase);
  1434. get_cout_stream().unsetf(ios::uppercase);
  1435. get_cout_stream().unsetf(ios::showpos);
  1436. get_cout_stream().unsetf(ios::scientific);
  1437. get_cout_stream().setf(ios::right);
  1438. get_cout_stream().setf(ios::fixed);
  1439. }
  1440. const uword X_n_cols = X.n_cols;
  1441. if(X_n_cols == 0) { return true; }
  1442. const uword N_dims = means.n_rows;
  1443. const uword N_gaus = means.n_cols;
  1444. const eT* mah_aux_mem = mah_aux.memptr();
  1445. Mat<eT> acc_means(N_dims, N_gaus, fill::zeros);
  1446. Row<uword> acc_hefts(N_gaus, fill::zeros);
  1447. Row<uword> last_indx(N_gaus, fill::zeros);
  1448. Mat<eT> new_means = means;
  1449. Mat<eT> old_means = means;
  1450. running_mean_scalar<eT> rs_delta;
  1451. #if defined(ARMA_USE_OPENMP)
  1452. const umat boundaries = internal_gen_boundaries(X_n_cols);
  1453. const uword n_threads = boundaries.n_cols;
  1454. field< Mat<eT> > t_acc_means(n_threads);
  1455. field< Row<uword> > t_acc_hefts(n_threads);
  1456. field< Row<uword> > t_last_indx(n_threads);
  1457. #else
  1458. const uword n_threads = 1;
  1459. #endif
  1460. if(verbose) { get_cout_stream() << signature << ": n_threads: " << n_threads << '\n'; get_cout_stream().flush(); }
  1461. for(uword iter=1; iter <= max_iter; ++iter)
  1462. {
  1463. #if defined(ARMA_USE_OPENMP)
  1464. {
  1465. for(uword t=0; t < n_threads; ++t)
  1466. {
  1467. t_acc_means(t).zeros(N_dims, N_gaus);
  1468. t_acc_hefts(t).zeros(N_gaus);
  1469. t_last_indx(t).zeros(N_gaus);
  1470. }
  1471. #pragma omp parallel for schedule(static)
  1472. for(uword t=0; t < n_threads; ++t)
  1473. {
  1474. Mat<eT>& t_acc_means_t = t_acc_means(t);
  1475. uword* t_acc_hefts_mem = t_acc_hefts(t).memptr();
  1476. uword* t_last_indx_mem = t_last_indx(t).memptr();
  1477. const uword start_index = boundaries.at(0,t);
  1478. const uword end_index = boundaries.at(1,t);
  1479. for(uword i=start_index; i <= end_index; ++i)
  1480. {
  1481. const eT* X_colptr = X.colptr(i);
  1482. eT min_dist = Datum<eT>::inf;
  1483. uword best_g = 0;
  1484. for(uword g=0; g<N_gaus; ++g)
  1485. {
  1486. const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, old_means.colptr(g), mah_aux_mem);
  1487. if(dist < min_dist) { min_dist = dist; best_g = g; }
  1488. }
  1489. eT* t_acc_mean = t_acc_means_t.colptr(best_g);
  1490. for(uword d=0; d<N_dims; ++d) { t_acc_mean[d] += X_colptr[d]; }
  1491. t_acc_hefts_mem[best_g]++;
  1492. t_last_indx_mem[best_g] = i;
  1493. }
  1494. }
  1495. // reduction
  1496. acc_means = t_acc_means(0);
  1497. acc_hefts = t_acc_hefts(0);
  1498. for(uword t=1; t < n_threads; ++t)
  1499. {
  1500. acc_means += t_acc_means(t);
  1501. acc_hefts += t_acc_hefts(t);
  1502. }
  1503. for(uword g=0; g < N_gaus; ++g)
  1504. for(uword t=0; t < n_threads; ++t)
  1505. {
  1506. if( t_acc_hefts(t)(g) >= 1 ) { last_indx(g) = t_last_indx(t)(g); }
  1507. }
  1508. }
  1509. #else
  1510. {
  1511. uword* acc_hefts_mem = acc_hefts.memptr();
  1512. uword* last_indx_mem = last_indx.memptr();
  1513. for(uword i=0; i < X_n_cols; ++i)
  1514. {
  1515. const eT* X_colptr = X.colptr(i);
  1516. eT min_dist = Datum<eT>::inf;
  1517. uword best_g = 0;
  1518. for(uword g=0; g<N_gaus; ++g)
  1519. {
  1520. const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, old_means.colptr(g), mah_aux_mem);
  1521. if(dist < min_dist) { min_dist = dist; best_g = g; }
  1522. }
  1523. eT* acc_mean = acc_means.colptr(best_g);
  1524. for(uword d=0; d<N_dims; ++d) { acc_mean[d] += X_colptr[d]; }
  1525. acc_hefts_mem[best_g]++;
  1526. last_indx_mem[best_g] = i;
  1527. }
  1528. }
  1529. #endif
  1530. // generate new means
  1531. uword* acc_hefts_mem = acc_hefts.memptr();
  1532. for(uword g=0; g < N_gaus; ++g)
  1533. {
  1534. const eT* acc_mean = acc_means.colptr(g);
  1535. const uword acc_heft = acc_hefts_mem[g];
  1536. eT* new_mean = access::rw(new_means).colptr(g);
  1537. for(uword d=0; d<N_dims; ++d)
  1538. {
  1539. new_mean[d] = (acc_heft >= 1) ? (acc_mean[d] / eT(acc_heft)) : eT(0);
  1540. }
  1541. }
  1542. // heuristics to resurrect dead means
  1543. const uvec dead_gs = find(acc_hefts == uword(0));
  1544. if(dead_gs.n_elem > 0)
  1545. {
  1546. if(verbose) { get_cout_stream() << signature << ": recovering from dead means\n"; get_cout_stream().flush(); }
  1547. uword* last_indx_mem = last_indx.memptr();
  1548. const uvec live_gs = sort( find(acc_hefts >= uword(2)), "descend" );
  1549. if(live_gs.n_elem == 0) { return false; }
  1550. uword live_gs_count = 0;
  1551. for(uword dead_gs_count = 0; dead_gs_count < dead_gs.n_elem; ++dead_gs_count)
  1552. {
  1553. const uword dead_g_id = dead_gs(dead_gs_count);
  1554. uword proposed_i = 0;
  1555. if(live_gs_count < live_gs.n_elem)
  1556. {
  1557. const uword live_g_id = live_gs(live_gs_count); ++live_gs_count;
  1558. if(live_g_id == dead_g_id) { return false; }
  1559. // recover by using a sample from a known good mean
  1560. proposed_i = last_indx_mem[live_g_id];
  1561. }
  1562. else
  1563. {
  1564. // recover by using a randomly seleced sample (last resort)
  1565. proposed_i = as_scalar(randi<uvec>(1, distr_param(0,X_n_cols-1)));
  1566. }
  1567. if(proposed_i >= X_n_cols) { return false; }
  1568. new_means.col(dead_g_id) = X.col(proposed_i);
  1569. }
  1570. }
  1571. rs_delta.reset();
  1572. for(uword g=0; g < N_gaus; ++g)
  1573. {
  1574. rs_delta( distance<eT,dist_id>::eval(N_dims, old_means.colptr(g), new_means.colptr(g), mah_aux_mem) );
  1575. }
  1576. if(verbose)
  1577. {
  1578. get_cout_stream() << signature << ": iteration: ";
  1579. get_cout_stream().unsetf(ios::scientific);
  1580. get_cout_stream().setf(ios::fixed);
  1581. get_cout_stream().width(std::streamsize(4));
  1582. get_cout_stream() << iter;
  1583. get_cout_stream() << " delta: ";
  1584. get_cout_stream().unsetf(ios::fixed);
  1585. //get_cout_stream().setf(ios::scientific);
  1586. get_cout_stream() << rs_delta.mean() << '\n';
  1587. get_cout_stream().flush();
  1588. }
  1589. arma::swap(old_means, new_means);
  1590. if(rs_delta.mean() <= Datum<eT>::eps) { break; }
  1591. }
  1592. access::rw(means) = old_means;
  1593. if(means.is_finite() == false) { return false; }
  1594. return true;
  1595. }
  1596. //! multi-threaded implementation of Expectation-Maximisation, inspired by MapReduce
  1597. template<typename eT>
  1598. inline
  1599. bool
  1600. gmm_diag<eT>::em_iterate(const Mat<eT>& X, const uword max_iter, const eT var_floor, const bool verbose)
  1601. {
  1602. arma_extra_debug_sigprint();
  1603. if(X.n_cols == 0) { return true; }
  1604. const uword N_dims = means.n_rows;
  1605. const uword N_gaus = means.n_cols;
  1606. if(verbose)
  1607. {
  1608. get_cout_stream().unsetf(ios::showbase);
  1609. get_cout_stream().unsetf(ios::uppercase);
  1610. get_cout_stream().unsetf(ios::showpos);
  1611. get_cout_stream().unsetf(ios::scientific);
  1612. get_cout_stream().setf(ios::right);
  1613. get_cout_stream().setf(ios::fixed);
  1614. }
  1615. const umat boundaries = internal_gen_boundaries(X.n_cols);
  1616. const uword n_threads = boundaries.n_cols;
  1617. field< Mat<eT> > t_acc_means(n_threads);
  1618. field< Mat<eT> > t_acc_dcovs(n_threads);
  1619. field< Col<eT> > t_acc_norm_lhoods(n_threads);
  1620. field< Col<eT> > t_gaus_log_lhoods(n_threads);
  1621. Col<eT> t_progress_log_lhood(n_threads);
  1622. for(uword t=0; t<n_threads; t++)
  1623. {
  1624. t_acc_means[t].set_size(N_dims, N_gaus);
  1625. t_acc_dcovs[t].set_size(N_dims, N_gaus);
  1626. t_acc_norm_lhoods[t].set_size(N_gaus);
  1627. t_gaus_log_lhoods[t].set_size(N_gaus);
  1628. }
  1629. if(verbose)
  1630. {
  1631. get_cout_stream() << "gmm_diag::learn(): EM: n_threads: " << n_threads << '\n';
  1632. }
  1633. eT old_avg_log_p = -Datum<eT>::inf;
  1634. for(uword iter=1; iter <= max_iter; ++iter)
  1635. {
  1636. init_constants();
  1637. em_update_params(X, boundaries, t_acc_means, t_acc_dcovs, t_acc_norm_lhoods, t_gaus_log_lhoods, t_progress_log_lhood);
  1638. em_fix_params(var_floor);
  1639. const eT new_avg_log_p = accu(t_progress_log_lhood) / eT(t_progress_log_lhood.n_elem);
  1640. if(verbose)
  1641. {
  1642. get_cout_stream() << "gmm_diag::learn(): EM: iteration: ";
  1643. get_cout_stream().unsetf(ios::scientific);
  1644. get_cout_stream().setf(ios::fixed);
  1645. get_cout_stream().width(std::streamsize(4));
  1646. get_cout_stream() << iter;
  1647. get_cout_stream() << " avg_log_p: ";
  1648. get_cout_stream().unsetf(ios::fixed);
  1649. //get_cout_stream().setf(ios::scientific);
  1650. get_cout_stream() << new_avg_log_p << '\n';
  1651. get_cout_stream().flush();
  1652. }
  1653. if(arma_isfinite(new_avg_log_p) == false) { return false; }
  1654. if(std::abs(old_avg_log_p - new_avg_log_p) <= Datum<eT>::eps) { break; }
  1655. old_avg_log_p = new_avg_log_p;
  1656. }
  1657. if(any(vectorise(dcovs) <= eT(0))) { return false; }
  1658. if(means.is_finite() == false ) { return false; }
  1659. if(dcovs.is_finite() == false ) { return false; }
  1660. if(hefts.is_finite() == false ) { return false; }
  1661. return true;
  1662. }
  1663. template<typename eT>
  1664. inline
  1665. void
  1666. gmm_diag<eT>::em_update_params
  1667. (
  1668. const Mat<eT>& X,
  1669. const umat& boundaries,
  1670. field< Mat<eT> >& t_acc_means,
  1671. field< Mat<eT> >& t_acc_dcovs,
  1672. field< Col<eT> >& t_acc_norm_lhoods,
  1673. field< Col<eT> >& t_gaus_log_lhoods,
  1674. Col<eT>& t_progress_log_lhood
  1675. )
  1676. {
  1677. arma_extra_debug_sigprint();
  1678. const uword n_threads = boundaries.n_cols;
  1679. // em_generate_acc() is the "map" operation, which produces partial accumulators for means, diagonal covariances and hefts
  1680. #if defined(ARMA_USE_OPENMP)
  1681. {
  1682. #pragma omp parallel for schedule(static)
  1683. for(uword t=0; t<n_threads; t++)
  1684. {
  1685. Mat<eT>& acc_means = t_acc_means[t];
  1686. Mat<eT>& acc_dcovs = t_acc_dcovs[t];
  1687. Col<eT>& acc_norm_lhoods = t_acc_norm_lhoods[t];
  1688. Col<eT>& gaus_log_lhoods = t_gaus_log_lhoods[t];
  1689. eT& progress_log_lhood = t_progress_log_lhood[t];
  1690. em_generate_acc(X, boundaries.at(0,t), boundaries.at(1,t), acc_means, acc_dcovs, acc_norm_lhoods, gaus_log_lhoods, progress_log_lhood);
  1691. }
  1692. }
  1693. #else
  1694. {
  1695. em_generate_acc(X, boundaries.at(0,0), boundaries.at(1,0), t_acc_means[0], t_acc_dcovs[0], t_acc_norm_lhoods[0], t_gaus_log_lhoods[0], t_progress_log_lhood[0]);
  1696. }
  1697. #endif
  1698. const uword N_dims = means.n_rows;
  1699. const uword N_gaus = means.n_cols;
  1700. Mat<eT>& final_acc_means = t_acc_means[0];
  1701. Mat<eT>& final_acc_dcovs = t_acc_dcovs[0];
  1702. Col<eT>& final_acc_norm_lhoods = t_acc_norm_lhoods[0];
  1703. // the "reduce" operation, which combines the partial accumulators produced by the separate threads
  1704. for(uword t=1; t<n_threads; t++)
  1705. {
  1706. final_acc_means += t_acc_means[t];
  1707. final_acc_dcovs += t_acc_dcovs[t];
  1708. final_acc_norm_lhoods += t_acc_norm_lhoods[t];
  1709. }
  1710. eT* hefts_mem = access::rw(hefts).memptr();
  1711. //// update each component without sanity checking
  1712. //for(uword g=0; g < N_gaus; ++g)
  1713. // {
  1714. // const eT acc_norm_lhood = (std::max)( final_acc_norm_lhoods[g], std::numeric_limits<eT>::min() );
  1715. //
  1716. // eT* mean_mem = access::rw(means).colptr(g);
  1717. // eT* dcov_mem = access::rw(dcovs).colptr(g);
  1718. //
  1719. // eT* acc_mean_mem = final_acc_means.colptr(g);
  1720. // eT* acc_dcov_mem = final_acc_dcovs.colptr(g);
  1721. //
  1722. // hefts_mem[g] = acc_norm_lhood / eT(X.n_cols);
  1723. //
  1724. // for(uword d=0; d < N_dims; ++d)
  1725. // {
  1726. // const eT tmp = acc_mean_mem[d] / acc_norm_lhood;
  1727. //
  1728. // mean_mem[d] = tmp;
  1729. // dcov_mem[d] = acc_dcov_mem[d] / acc_norm_lhood - tmp*tmp;
  1730. // }
  1731. // }
  1732. // conditionally update each component; if only a subset of the hefts was updated, em_fix_params() will sanitise them
  1733. for(uword g=0; g < N_gaus; ++g)
  1734. {
  1735. const eT acc_norm_lhood = (std::max)( final_acc_norm_lhoods[g], std::numeric_limits<eT>::min() );
  1736. if(arma_isfinite(acc_norm_lhood) == false) { continue; }
  1737. eT* acc_mean_mem = final_acc_means.colptr(g);
  1738. eT* acc_dcov_mem = final_acc_dcovs.colptr(g);
  1739. bool ok = true;
  1740. for(uword d=0; d < N_dims; ++d)
  1741. {
  1742. const eT tmp1 = acc_mean_mem[d] / acc_norm_lhood;
  1743. const eT tmp2 = acc_dcov_mem[d] / acc_norm_lhood - tmp1*tmp1;
  1744. acc_mean_mem[d] = tmp1;
  1745. acc_dcov_mem[d] = tmp2;
  1746. if(arma_isfinite(tmp2) == false) { ok = false; }
  1747. }
  1748. if(ok)
  1749. {
  1750. hefts_mem[g] = acc_norm_lhood / eT(X.n_cols);
  1751. eT* mean_mem = access::rw(means).colptr(g);
  1752. eT* dcov_mem = access::rw(dcovs).colptr(g);
  1753. for(uword d=0; d < N_dims; ++d)
  1754. {
  1755. mean_mem[d] = acc_mean_mem[d];
  1756. dcov_mem[d] = acc_dcov_mem[d];
  1757. }
  1758. }
  1759. }
  1760. }
  1761. template<typename eT>
  1762. inline
  1763. void
  1764. gmm_diag<eT>::em_generate_acc
  1765. (
  1766. const Mat<eT>& X,
  1767. const uword start_index,
  1768. const uword end_index,
  1769. Mat<eT>& acc_means,
  1770. Mat<eT>& acc_dcovs,
  1771. Col<eT>& acc_norm_lhoods,
  1772. Col<eT>& gaus_log_lhoods,
  1773. eT& progress_log_lhood
  1774. )
  1775. const
  1776. {
  1777. arma_extra_debug_sigprint();
  1778. progress_log_lhood = eT(0);
  1779. acc_means.zeros();
  1780. acc_dcovs.zeros();
  1781. acc_norm_lhoods.zeros();
  1782. gaus_log_lhoods.zeros();
  1783. const uword N_dims = means.n_rows;
  1784. const uword N_gaus = means.n_cols;
  1785. const eT* log_hefts_mem = log_hefts.memptr();
  1786. eT* gaus_log_lhoods_mem = gaus_log_lhoods.memptr();
  1787. for(uword i=start_index; i <= end_index; i++)
  1788. {
  1789. const eT* x = X.colptr(i);
  1790. for(uword g=0; g < N_gaus; ++g)
  1791. {
  1792. gaus_log_lhoods_mem[g] = internal_scalar_log_p(x, g) + log_hefts_mem[g];
  1793. }
  1794. eT log_lhood_sum = gaus_log_lhoods_mem[0];
  1795. for(uword g=1; g < N_gaus; ++g)
  1796. {
  1797. log_lhood_sum = log_add_exp(log_lhood_sum, gaus_log_lhoods_mem[g]);
  1798. }
  1799. progress_log_lhood += log_lhood_sum;
  1800. for(uword g=0; g < N_gaus; ++g)
  1801. {
  1802. const eT norm_lhood = std::exp(gaus_log_lhoods_mem[g] - log_lhood_sum);
  1803. acc_norm_lhoods[g] += norm_lhood;
  1804. eT* acc_mean_mem = acc_means.colptr(g);
  1805. eT* acc_dcov_mem = acc_dcovs.colptr(g);
  1806. for(uword d=0; d < N_dims; ++d)
  1807. {
  1808. const eT x_d = x[d];
  1809. const eT y_d = x_d * norm_lhood;
  1810. acc_mean_mem[d] += y_d;
  1811. acc_dcov_mem[d] += y_d * x_d; // equivalent to x_d * x_d * norm_lhood
  1812. }
  1813. }
  1814. }
  1815. progress_log_lhood /= eT((end_index - start_index) + 1);
  1816. }
  1817. template<typename eT>
  1818. inline
  1819. void
  1820. gmm_diag<eT>::em_fix_params(const eT var_floor)
  1821. {
  1822. arma_extra_debug_sigprint();
  1823. const uword N_dims = means.n_rows;
  1824. const uword N_gaus = means.n_cols;
  1825. const eT var_ceiling = std::numeric_limits<eT>::max();
  1826. const uword dcovs_n_elem = dcovs.n_elem;
  1827. eT* dcovs_mem = access::rw(dcovs).memptr();
  1828. for(uword i=0; i < dcovs_n_elem; ++i)
  1829. {
  1830. eT& var_val = dcovs_mem[i];
  1831. if(var_val < var_floor ) { var_val = var_floor; }
  1832. else if(var_val > var_ceiling) { var_val = var_ceiling; }
  1833. else if(arma_isnan(var_val) ) { var_val = eT(1); }
  1834. }
  1835. eT* hefts_mem = access::rw(hefts).memptr();
  1836. for(uword g1=0; g1 < N_gaus; ++g1)
  1837. {
  1838. if(hefts_mem[g1] > eT(0))
  1839. {
  1840. const eT* means_colptr_g1 = means.colptr(g1);
  1841. for(uword g2=(g1+1); g2 < N_gaus; ++g2)
  1842. {
  1843. if( (hefts_mem[g2] > eT(0)) && (std::abs(hefts_mem[g1] - hefts_mem[g2]) <= std::numeric_limits<eT>::epsilon()) )
  1844. {
  1845. const eT dist = distance<eT,1>::eval(N_dims, means_colptr_g1, means.colptr(g2), means_colptr_g1);
  1846. if(dist == eT(0)) { hefts_mem[g2] = eT(0); }
  1847. }
  1848. }
  1849. }
  1850. }
  1851. const eT heft_floor = std::numeric_limits<eT>::min();
  1852. const eT heft_initial = eT(1) / eT(N_gaus);
  1853. for(uword i=0; i < N_gaus; ++i)
  1854. {
  1855. eT& heft_val = hefts_mem[i];
  1856. if(heft_val < heft_floor) { heft_val = heft_floor; }
  1857. else if(heft_val > eT(1) ) { heft_val = eT(1); }
  1858. else if(arma_isnan(heft_val) ) { heft_val = heft_initial; }
  1859. }
  1860. const eT heft_sum = accu(hefts);
  1861. if((heft_sum < (eT(1) - Datum<eT>::eps)) || (heft_sum > (eT(1) + Datum<eT>::eps))) { access::rw(hefts) /= heft_sum; }
  1862. }
  1863. } // namespace gmm_priv
  1864. //! @}