// gmm.cpp — Catch unit tests verifying that Armadillo's gmm_full and gmm_diag
// can recover manually constructed Gaussian mixture models.
  1. #include <vector>
  2. #include <armadillo>
  3. #include "catch.hpp"
  4. using namespace std;
  5. using namespace arma;
  6. /**
  7. * Make sure that gmm_full can fit manually constructed Gaussians.
  8. */
  9. TEST_CASE("gmm_full_1")
  10. {
  11. // Higher dimensionality gives us a greater chance of having separated Gaussians.
  12. const uword dims = 8;
  13. const uword gaussians = 3;
  14. const uword maxTrials = 3;
  15. // Generate dataset.
  16. mat data(dims, 500, fill::zeros);
  17. vector<vec> means(gaussians);
  18. vector<mat> covars(gaussians);
  19. vec weights(gaussians);
  20. uvec counts(gaussians);
  21. bool success = true;
  22. for(uword trial = 0; trial < maxTrials; ++trial)
  23. {
  24. // Preset weights.
  25. weights[0] = 0.25;
  26. weights[1] = 0.325;
  27. weights[2] = 0.425;
  28. for(size_t i = 0; i < gaussians; i++)
  29. {
  30. counts[i] = data.n_cols * weights[i];
  31. }
  32. // Account for rounding errors (possibly necessary).
  33. counts[gaussians - 1] += (data.n_cols - accu(counts));
  34. // Build each Gaussian individually.
  35. size_t point = 0;
  36. for(size_t i = 0; i < gaussians; i++)
  37. {
  38. mat gaussian;
  39. gaussian.randn(dims, counts[i]);
  40. // Randomly generate mean and covariance.
  41. means[i].randu(dims);
  42. means[i] -= 0.5;
  43. means[i] *= (5 * i);
  44. // We need to make sure the covariance is positive definite.
  45. // We will take a random matrix C and then set our covariance to C * C',
  46. // which will be positive semidefinite.
  47. covars[i].randu(dims, dims);
  48. covars[i] += 0.5 * eye<mat>(dims, dims);
  49. covars[i] *= trans(covars[i]);
  50. data.cols(point, point + counts[i] - 1) = (covars[i] * gaussian + means[i] * ones<rowvec>(counts[i]));
  51. // Calculate the actual means and covariances
  52. // because they will probably be different.
  53. means[i] = mean(data.cols(point, point + counts[i] - 1), 1);
  54. covars[i] = cov(data.cols(point, point + counts[i] - 1).t(), 1 /* biased */);
  55. point += counts[i];
  56. }
  57. // Calculate actual weights.
  58. for(size_t i = 0; i < gaussians; i++)
  59. {
  60. weights[i] = (double) counts[i] / data.n_cols;
  61. }
  62. gmm_full gmm;
  63. gmm.learn(data, gaussians, eucl_dist, random_subset, 10, 500, 1e-10, false);
  64. uvec sortRef = sort_index(weights);
  65. uvec sortTry = sort_index(gmm.hefts);
  66. // Check the model to see that it is correct.
  67. success = ( gmm.hefts.n_elem == gaussians );
  68. for(size_t i = 0; i < gaussians; i++)
  69. {
  70. // Check weight.
  71. success &= ( weights[sortRef[i]] == Approx(gmm.hefts[sortTry[i]]).epsilon(0.1) );
  72. for(uword j = 0; j < gmm.means.n_rows; ++j)
  73. {
  74. success &= ( means[sortRef[i]][j] == Approx(gmm.means(j, sortTry[i])).epsilon(0.1) );
  75. }
  76. for(uword j = 0; j < gmm.fcovs.n_rows * gmm.fcovs.n_cols; ++j)
  77. {
  78. success &= ( covars[sortRef[i]][j] == Approx(gmm.fcovs.slice(sortTry[i])[j]).epsilon(0.1) );
  79. }
  80. if(success == false) { continue; }
  81. }
  82. if(success) { break; }
  83. }
  84. REQUIRE( success == true );
  85. }
  86. TEST_CASE("gmm_diag_1")
  87. {
  88. // Higher dimensionality gives us a greater chance of having separated Gaussians.
  89. const uword dims = 4;
  90. const uword gaussians = 3;
  91. const uword maxTrials = 8; // Needs more trials...
  92. // Generate dataset.
  93. mat data(dims, 500, fill::zeros);
  94. vector<vec> means(gaussians);
  95. vector<mat> covars(gaussians);
  96. vec weights(gaussians);
  97. uvec counts(gaussians);
  98. bool success = true;
  99. for(uword trial = 0; trial < maxTrials; ++trial)
  100. {
  101. // Preset weights.
  102. weights[0] = 0.25;
  103. weights[1] = 0.325;
  104. weights[2] = 0.425;
  105. for(size_t i = 0; i < gaussians; i++)
  106. {
  107. counts[i] = data.n_cols * weights[i];
  108. }
  109. // Account for rounding errors (possibly necessary).
  110. counts[gaussians - 1] += (data.n_cols - accu(counts));
  111. // Build each Gaussian individually.
  112. size_t point = 0;
  113. for(size_t i = 0; i < gaussians; i++)
  114. {
  115. mat gaussian;
  116. gaussian.randn(dims, counts[i]);
  117. // Randomly generate mean and covariance.
  118. means[i].randu(dims);
  119. means[i] -= 0.5;
  120. means[i] *= (3 * (i + 1));
  121. // Use a diagonal covariance matrix.
  122. covars[i].zeros(dims, dims);
  123. covars[i].diag() = 0.5 * randu<vec>(dims) + 0.5;
  124. data.cols(point, point + counts[i] - 1) = (covars[i] * gaussian + means[i] * ones<rowvec>(counts[i]));
  125. // Calculate the actual means and covariances
  126. // because they will probably be different.
  127. means[i] = mean(data.cols(point, point + counts[i] - 1), 1);
  128. covars[i] = cov(data.cols(point, point + counts[i] - 1).t(), 1 /* biased */);
  129. point += counts[i];
  130. }
  131. // Calculate actual weights.
  132. for(size_t i = 0; i < gaussians; i++)
  133. {
  134. weights[i] = (double) counts[i] / data.n_cols;
  135. }
  136. gmm_diag gmm;
  137. gmm.learn(data, gaussians, eucl_dist, random_subset, 50, 500, 1e-10, false);
  138. uvec sortRef = sort_index(weights);
  139. uvec sortTry = sort_index(gmm.hefts);
  140. // Check the model to see that it is correct.
  141. success = ( gmm.hefts.n_elem == gaussians );
  142. for(size_t i = 0; i < gaussians; i++)
  143. {
  144. // Check weight.
  145. success &= ( weights[sortRef[i]] == Approx(gmm.hefts[sortTry[i]]).epsilon(0.1) );
  146. for(uword j = 0; j < gmm.means.n_rows; ++j)
  147. {
  148. success &= ( means[sortRef[i]][j] == Approx(gmm.means(j, sortTry[i])).epsilon(0.1) );
  149. }
  150. for(uword j = 0; j < gmm.dcovs.n_rows; ++j)
  151. {
  152. success &= ( covars[sortRef[i]](j, j) == Approx(gmm.dcovs.col(sortTry[i])[j]).epsilon(0.1) );
  153. }
  154. if(success == false) { continue; }
  155. }
  156. if(success) { break; }
  157. }
  158. REQUIRE( success == true );
  159. }