test_em.cpp 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. // This file is part of OpenCV project.
  2. // It is subject to the license terms in the LICENSE file found in the top-level directory
  3. // of this distribution and at http://opencv.org/license.html.
  4. #include "test_precomp.hpp"
  5. namespace opencv_test { namespace {
  6. CV_ENUM(EM_START_STEP, EM::START_AUTO_STEP, EM::START_M_STEP, EM::START_E_STEP)
  7. CV_ENUM(EM_COV_MAT, EM::COV_MAT_GENERIC, EM::COV_MAT_DIAGONAL, EM::COV_MAT_SPHERICAL)
  8. typedef testing::TestWithParam< tuple<EM_START_STEP, EM_COV_MAT> > ML_EM_Params;
  9. TEST_P(ML_EM_Params, accuracy)
  10. {
  11. const int nclusters = 3;
  12. const int sizesArr[] = { 500, 700, 800 };
  13. const vector<int> sizes( sizesArr, sizesArr + sizeof(sizesArr) / sizeof(sizesArr[0]) );
  14. const int pointsCount = sizesArr[0] + sizesArr[1] + sizesArr[2];
  15. Mat means;
  16. vector<Mat> covs;
  17. defaultDistribs( means, covs, CV_64FC1 );
  18. Mat trainData(pointsCount, 2, CV_64FC1 );
  19. Mat trainLabels;
  20. generateData( trainData, trainLabels, sizes, means, covs, CV_64FC1, CV_32SC1 );
  21. Mat testData( pointsCount, 2, CV_64FC1 );
  22. Mat testLabels;
  23. generateData( testData, testLabels, sizes, means, covs, CV_64FC1, CV_32SC1 );
  24. Mat probs(trainData.rows, nclusters, CV_64FC1, cv::Scalar(1));
  25. Mat weights(1, nclusters, CV_64FC1, cv::Scalar(1));
  26. TermCriteria termCrit(cv::TermCriteria::COUNT + cv::TermCriteria::EPS, 100, FLT_EPSILON);
  27. int startStep = get<0>(GetParam());
  28. int covMatType = get<1>(GetParam());
  29. cv::Mat labels;
  30. Ptr<EM> em = EM::create();
  31. em->setClustersNumber(nclusters);
  32. em->setCovarianceMatrixType(covMatType);
  33. em->setTermCriteria(termCrit);
  34. if( startStep == EM::START_AUTO_STEP )
  35. em->trainEM( trainData, noArray(), labels, noArray() );
  36. else if( startStep == EM::START_E_STEP )
  37. em->trainE( trainData, means, covs, weights, noArray(), labels, noArray() );
  38. else if( startStep == EM::START_M_STEP )
  39. em->trainM( trainData, probs, noArray(), labels, noArray() );
  40. {
  41. SCOPED_TRACE("Train");
  42. float err = 1000;
  43. EXPECT_TRUE(calcErr( labels, trainLabels, sizes, err , false, false ));
  44. EXPECT_LE(err, 0.008f);
  45. }
  46. {
  47. SCOPED_TRACE("Test");
  48. float err = 1000;
  49. labels.create( testData.rows, 1, CV_32SC1 );
  50. for( int i = 0; i < testData.rows; i++ )
  51. {
  52. Mat sample = testData.row(i);
  53. Mat out_probs;
  54. labels.at<int>(i) = static_cast<int>(em->predict2( sample, out_probs )[1]);
  55. }
  56. EXPECT_TRUE(calcErr( labels, testLabels, sizes, err, false, false ));
  57. EXPECT_LE(err, 0.008f);
  58. }
  59. }
  60. INSTANTIATE_TEST_CASE_P(/**/, ML_EM_Params,
  61. testing::Combine(
  62. testing::Values(EM::START_AUTO_STEP, EM::START_M_STEP, EM::START_E_STEP),
  63. testing::Values(EM::COV_MAT_GENERIC, EM::COV_MAT_DIAGONAL, EM::COV_MAT_SPHERICAL)
  64. ));
  65. //==================================================================================================
  66. TEST(ML_EM, save_load)
  67. {
  68. const int nclusters = 2;
  69. Mat_<double> samples(3, 1);
  70. samples << 1., 2., 3.;
  71. std::vector<double> firstResult;
  72. string filename = cv::tempfile(".xml");
  73. {
  74. Mat labels;
  75. Ptr<EM> em = EM::create();
  76. em->setClustersNumber(nclusters);
  77. em->trainEM(samples, noArray(), labels, noArray());
  78. for( int i = 0; i < samples.rows; i++)
  79. {
  80. Vec2d res = em->predict2(samples.row(i), noArray());
  81. firstResult.push_back(res[1]);
  82. }
  83. {
  84. FileStorage fs = FileStorage(filename, FileStorage::WRITE);
  85. ASSERT_NO_THROW(fs << "em" << "{");
  86. ASSERT_NO_THROW(em->write(fs));
  87. ASSERT_NO_THROW(fs << "}");
  88. }
  89. }
  90. {
  91. Ptr<EM> em;
  92. ASSERT_NO_THROW(em = Algorithm::load<EM>(filename));
  93. for( int i = 0; i < samples.rows; i++)
  94. {
  95. SCOPED_TRACE(i);
  96. Vec2d res = em->predict2(samples.row(i), noArray());
  97. EXPECT_DOUBLE_EQ(firstResult[i], res[1]);
  98. }
  99. }
  100. remove(filename.c_str());
  101. }
  102. //==================================================================================================
  103. TEST(ML_EM, classification)
  104. {
  105. // This test classifies spam by the following way:
  106. // 1. estimates distributions of "spam" / "not spam"
  107. // 2. predict classID using Bayes classifier for estimated distributions.
  108. string dataFilename = findDataFile("spambase.data");
  109. Ptr<TrainData> data = TrainData::loadFromCSV(dataFilename, 0);
  110. ASSERT_FALSE(data.empty());
  111. Mat samples = data->getSamples();
  112. ASSERT_EQ(samples.cols, 57);
  113. Mat responses = data->getResponses();
  114. vector<int> trainSamplesMask(samples.rows, 0);
  115. const int trainSamplesCount = (int)(0.5f * samples.rows);
  116. const int testSamplesCount = samples.rows - trainSamplesCount;
  117. for(int i = 0; i < trainSamplesCount; i++)
  118. trainSamplesMask[i] = 1;
  119. RNG &rng = cv::theRNG();
  120. for(size_t i = 0; i < trainSamplesMask.size(); i++)
  121. {
  122. int i1 = rng(static_cast<unsigned>(trainSamplesMask.size()));
  123. int i2 = rng(static_cast<unsigned>(trainSamplesMask.size()));
  124. std::swap(trainSamplesMask[i1], trainSamplesMask[i2]);
  125. }
  126. Mat samples0, samples1;
  127. for(int i = 0; i < samples.rows; i++)
  128. {
  129. if(trainSamplesMask[i])
  130. {
  131. Mat sample = samples.row(i);
  132. int resp = (int)responses.at<float>(i);
  133. if(resp == 0)
  134. samples0.push_back(sample);
  135. else
  136. samples1.push_back(sample);
  137. }
  138. }
  139. Ptr<EM> model0 = EM::create();
  140. model0->setClustersNumber(3);
  141. model0->trainEM(samples0, noArray(), noArray(), noArray());
  142. Ptr<EM> model1 = EM::create();
  143. model1->setClustersNumber(3);
  144. model1->trainEM(samples1, noArray(), noArray(), noArray());
  145. // confusion matrices
  146. Mat_<int> trainCM(2, 2, 0);
  147. Mat_<int> testCM(2, 2, 0);
  148. const double lambda = 1.;
  149. for(int i = 0; i < samples.rows; i++)
  150. {
  151. Mat sample = samples.row(i);
  152. double sampleLogLikelihoods0 = model0->predict2(sample, noArray())[0];
  153. double sampleLogLikelihoods1 = model1->predict2(sample, noArray())[0];
  154. int classID = (sampleLogLikelihoods0 >= lambda * sampleLogLikelihoods1) ? 0 : 1;
  155. int resp = (int)responses.at<float>(i);
  156. EXPECT_TRUE(resp == 0 || resp == 1);
  157. if(trainSamplesMask[i])
  158. trainCM(resp, classID)++;
  159. else
  160. testCM(resp, classID)++;
  161. }
  162. EXPECT_LE((double)(trainCM(1,0) + trainCM(0,1)) / trainSamplesCount, 0.23);
  163. EXPECT_LE((double)(testCM(1,0) + testCM(0,1)) / testSamplesCount, 0.26);
  164. }
  165. }} // namespace