test_svmsgd.cpp 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. // This file is part of OpenCV project.
  2. // It is subject to the license terms in the LICENSE file found in the top-level directory
  3. // of this distribution and at http://opencv.org/license.html.
  4. #include "test_precomp.hpp"
  5. namespace opencv_test { namespace {
  6. static const int TEST_VALUE_LIMIT = 500;
  7. enum
  8. {
  9. UNIFORM_SAME_SCALE,
  10. UNIFORM_DIFFERENT_SCALES
  11. };
  12. CV_ENUM(SVMSGD_TYPE, UNIFORM_SAME_SCALE, UNIFORM_DIFFERENT_SCALES)
  13. typedef std::vector< std::pair<float,float> > BorderList;
  14. static void makeData(RNG &rng, int samplesCount, const Mat &weights, float shift, const BorderList & borders, Mat &samples, Mat & responses)
  15. {
  16. int featureCount = weights.cols;
  17. samples.create(samplesCount, featureCount, CV_32FC1);
  18. for (int featureIndex = 0; featureIndex < featureCount; featureIndex++)
  19. rng.fill(samples.col(featureIndex), RNG::UNIFORM, borders[featureIndex].first, borders[featureIndex].second);
  20. responses.create(samplesCount, 1, CV_32FC1);
  21. for (int i = 0 ; i < samplesCount; i++)
  22. {
  23. double res = samples.row(i).dot(weights) + shift;
  24. responses.at<float>(i) = res > 0 ? 1.f : -1.f;
  25. }
  26. }
  27. //==================================================================================================
  28. typedef tuple<SVMSGD_TYPE, int, double> ML_SVMSGD_Param;
  29. typedef testing::TestWithParam<ML_SVMSGD_Param> ML_SVMSGD_Params;
  30. TEST_P(ML_SVMSGD_Params, scale_and_features)
  31. {
  32. const int type = get<0>(GetParam());
  33. const int featureCount = get<1>(GetParam());
  34. const double precision = get<2>(GetParam());
  35. RNG &rng = cv::theRNG();
  36. Mat_<float> weights(1, featureCount);
  37. rng.fill(weights, RNG::UNIFORM, -1, 1);
  38. const float shift = static_cast<float>(rng.uniform(-featureCount, featureCount));
  39. BorderList borders;
  40. float lowerLimit = -TEST_VALUE_LIMIT;
  41. float upperLimit = TEST_VALUE_LIMIT;
  42. if (type == UNIFORM_SAME_SCALE)
  43. {
  44. for (int featureIndex = 0; featureIndex < featureCount; featureIndex++)
  45. borders.push_back(std::pair<float,float>(lowerLimit, upperLimit));
  46. }
  47. else if (type == UNIFORM_DIFFERENT_SCALES)
  48. {
  49. for (int featureIndex = 0; featureIndex < featureCount; featureIndex++)
  50. {
  51. int crit = rng.uniform(0, 2);
  52. if (crit > 0)
  53. borders.push_back(std::pair<float,float>(lowerLimit, upperLimit));
  54. else
  55. borders.push_back(std::pair<float,float>(lowerLimit/1000, upperLimit/1000));
  56. }
  57. }
  58. ASSERT_FALSE(borders.empty());
  59. Mat trainSamples;
  60. Mat trainResponses;
  61. int trainSamplesCount = 10000;
  62. makeData(rng, trainSamplesCount, weights, shift, borders, trainSamples, trainResponses);
  63. ASSERT_EQ(trainResponses.type(), CV_32FC1);
  64. Mat testSamples;
  65. Mat testResponses;
  66. int testSamplesCount = 100000;
  67. makeData(rng, testSamplesCount, weights, shift, borders, testSamples, testResponses);
  68. ASSERT_EQ(testResponses.type(), CV_32FC1);
  69. Ptr<TrainData> data = TrainData::create(trainSamples, cv::ml::ROW_SAMPLE, trainResponses);
  70. ASSERT_TRUE(data);
  71. cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();
  72. ASSERT_TRUE(svmsgd);
  73. svmsgd->train(data);
  74. Mat responses;
  75. svmsgd->predict(testSamples, responses);
  76. ASSERT_EQ(responses.type(), CV_32FC1);
  77. ASSERT_EQ(responses.rows, testSamplesCount);
  78. int errCount = 0;
  79. for (int i = 0; i < testSamplesCount; i++)
  80. if (responses.at<float>(i) * testResponses.at<float>(i) < 0)
  81. errCount++;
  82. float err = (float)errCount / testSamplesCount;
  83. EXPECT_LE(err, precision);
  84. }
  85. ML_SVMSGD_Param params_list[] = {
  86. ML_SVMSGD_Param(UNIFORM_SAME_SCALE, 2, 0.01),
  87. ML_SVMSGD_Param(UNIFORM_SAME_SCALE, 5, 0.01),
  88. ML_SVMSGD_Param(UNIFORM_SAME_SCALE, 100, 0.02),
  89. ML_SVMSGD_Param(UNIFORM_DIFFERENT_SCALES, 2, 0.01),
  90. ML_SVMSGD_Param(UNIFORM_DIFFERENT_SCALES, 5, 0.01),
  91. ML_SVMSGD_Param(UNIFORM_DIFFERENT_SCALES, 100, 0.01),
  92. };
  93. INSTANTIATE_TEST_CASE_P(/**/, ML_SVMSGD_Params, testing::ValuesIn(params_list));
  94. //==================================================================================================
  95. TEST(ML_SVMSGD, twoPoints)
  96. {
  97. Mat samples(2, 2, CV_32FC1);
  98. samples.at<float>(0,0) = 0;
  99. samples.at<float>(0,1) = 0;
  100. samples.at<float>(1,0) = 1000;
  101. samples.at<float>(1,1) = 1;
  102. Mat responses(2, 1, CV_32FC1);
  103. responses.at<float>(0) = -1;
  104. responses.at<float>(1) = 1;
  105. cv::Ptr<TrainData> trainData = TrainData::create(samples, cv::ml::ROW_SAMPLE, responses);
  106. Mat realWeights(1, 2, CV_32FC1);
  107. realWeights.at<float>(0) = 1000;
  108. realWeights.at<float>(1) = 1;
  109. float realShift = -500000.5;
  110. float normRealWeights = static_cast<float>(cv::norm(realWeights)); // TODO cvtest
  111. realWeights /= normRealWeights;
  112. realShift /= normRealWeights;
  113. cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();
  114. svmsgd->setOptimalParameters();
  115. svmsgd->train( trainData );
  116. Mat foundWeights = svmsgd->getWeights();
  117. float foundShift = svmsgd->getShift();
  118. float normFoundWeights = static_cast<float>(cv::norm(foundWeights)); // TODO cvtest
  119. foundWeights /= normFoundWeights;
  120. foundShift /= normFoundWeights;
  121. EXPECT_LE(cv::norm(Mat(foundWeights - realWeights)), 0.001); // TODO cvtest
  122. EXPECT_LE(std::abs((foundShift - realShift) / realShift), 0.05);
  123. }
  124. }} // namespace