test_knearest.cpp 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. // This file is part of OpenCV project.
  2. // It is subject to the license terms in the LICENSE file found in the top-level directory
  3. // of this distribution and at http://opencv.org/license.html.
  4. #include "test_precomp.hpp"
  5. namespace opencv_test { namespace {
  6. using cv::ml::TrainData;
  7. using cv::ml::EM;
  8. using cv::ml::KNearest;
  9. TEST(ML_KNearest, accuracy)
  10. {
  11. int sizesArr[] = { 500, 700, 800 };
  12. int pointsCount = sizesArr[0]+ sizesArr[1] + sizesArr[2];
  13. Mat trainData( pointsCount, 2, CV_32FC1 ), trainLabels;
  14. vector<int> sizes( sizesArr, sizesArr + sizeof(sizesArr) / sizeof(sizesArr[0]) );
  15. Mat means;
  16. vector<Mat> covs;
  17. defaultDistribs( means, covs );
  18. generateData( trainData, trainLabels, sizes, means, covs, CV_32FC1, CV_32FC1 );
  19. Mat testData( pointsCount, 2, CV_32FC1 );
  20. Mat testLabels;
  21. generateData( testData, testLabels, sizes, means, covs, CV_32FC1, CV_32FC1 );
  22. {
  23. SCOPED_TRACE("Default");
  24. Mat bestLabels;
  25. float err = 1000;
  26. Ptr<KNearest> knn = KNearest::create();
  27. knn->train(trainData, ml::ROW_SAMPLE, trainLabels);
  28. knn->findNearest(testData, 4, bestLabels);
  29. EXPECT_TRUE(calcErr( bestLabels, testLabels, sizes, err, true ));
  30. EXPECT_LE(err, 0.01f);
  31. }
  32. {
  33. SCOPED_TRACE("KDTree");
  34. Mat neighborIndexes;
  35. float err = 1000;
  36. Ptr<KNearest> knn = KNearest::create();
  37. knn->setAlgorithmType(KNearest::KDTREE);
  38. knn->train(trainData, ml::ROW_SAMPLE, trainLabels);
  39. knn->findNearest(testData, 4, neighborIndexes);
  40. Mat bestLabels;
  41. // The output of the KDTree are the neighbor indexes, not actual class labels
  42. // so we need to do some extra work to get actual predictions
  43. for(int row_num = 0; row_num < neighborIndexes.rows; ++row_num){
  44. vector<float> labels;
  45. for(int index = 0; index < neighborIndexes.row(row_num).cols; ++index) {
  46. labels.push_back(trainLabels.at<float>(neighborIndexes.row(row_num).at<int>(0, index) , 0));
  47. }
  48. // computing the mode of the output class predictions to determine overall prediction
  49. std::vector<int> histogram(3,0);
  50. for( int i=0; i<3; ++i )
  51. ++histogram[ static_cast<int>(labels[i]) ];
  52. int bestLabel = static_cast<int>(std::max_element( histogram.begin(), histogram.end() ) - histogram.begin());
  53. bestLabels.push_back(bestLabel);
  54. }
  55. bestLabels.convertTo(bestLabels, testLabels.type());
  56. EXPECT_TRUE(calcErr( bestLabels, testLabels, sizes, err, true ));
  57. EXPECT_LE(err, 0.01f);
  58. }
  59. }
  60. TEST(ML_KNearest, regression_12347)
  61. {
  62. Mat xTrainData = (Mat_<float>(5,2) << 1, 1.1, 1.1, 1, 2, 2, 2.1, 2, 2.1, 2.1);
  63. Mat yTrainLabels = (Mat_<float>(5,1) << 1, 1, 2, 2, 2);
  64. Ptr<KNearest> knn = KNearest::create();
  65. knn->train(xTrainData, ml::ROW_SAMPLE, yTrainLabels);
  66. Mat xTestData = (Mat_<float>(2,2) << 1.1, 1.1, 2, 2.2);
  67. Mat zBestLabels, neighbours, dist;
  68. // check output shapes:
  69. int K = 16, Kexp = std::min(K, xTrainData.rows);
  70. knn->findNearest(xTestData, K, zBestLabels, neighbours, dist);
  71. EXPECT_EQ(xTestData.rows, zBestLabels.rows);
  72. EXPECT_EQ(neighbours.cols, Kexp);
  73. EXPECT_EQ(dist.cols, Kexp);
  74. // see if the result is still correct:
  75. K = 2;
  76. knn->findNearest(xTestData, K, zBestLabels, neighbours, dist);
  77. EXPECT_EQ(1, zBestLabels.at<float>(0,0));
  78. EXPECT_EQ(2, zBestLabels.at<float>(1,0));
  79. }
  80. TEST(ML_KNearest, bug_11877)
  81. {
  82. Mat trainData = (Mat_<float>(5,2) << 3, 3, 3, 3, 4, 4, 4, 4, 4, 4);
  83. Mat trainLabels = (Mat_<float>(5,1) << 0, 0, 1, 1, 1);
  84. Ptr<KNearest> knnKdt = KNearest::create();
  85. knnKdt->setAlgorithmType(KNearest::KDTREE);
  86. knnKdt->setIsClassifier(true);
  87. knnKdt->train(trainData, ml::ROW_SAMPLE, trainLabels);
  88. Mat testData = (Mat_<float>(2,2) << 3.1, 3.1, 4, 4.1);
  89. Mat testLabels = (Mat_<int>(2,1) << 0, 1);
  90. Mat result;
  91. knnKdt->findNearest(testData, 1, result);
  92. EXPECT_EQ(1, int(result.at<int>(0, 0)));
  93. EXPECT_EQ(2, int(result.at<int>(1, 0)));
  94. EXPECT_EQ(0, trainLabels.at<int>(result.at<int>(0, 0), 0));
  95. }
  96. }} // namespace