// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#include "test_precomp.hpp"

namespace opencv_test { namespace {

TEST(ML_RTrees, getVotes)
{
    int n = 12;
    int count, i;
    int label_size = 3;
    int predicted_class = 0;
    int max_votes = -1;
    int val;

    // RTrees for classification
    Ptr<ml::RTrees> rt = cv::ml::RTrees::create();

    // data
    Mat data(n, 4, CV_32F);
    randu(data, 0, 10);

    // labels
    Mat labels = (Mat_<int>(n,1) << 0,0,0,0, 1,1,1,1, 2,2,2,2);

    rt->train(data, ml::ROW_SAMPLE, labels);

    // run function
    Mat test(1, 4, CV_32F);
    Mat result;
    randu(test, 0, 10);
    rt->getVotes(test, result, 0);
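    // getVotes stores the class labels in row 0 of 'result' and the per-class vote
    // counts for each input sample in the rows below (row 1 for this single test sample)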

    // count vote amount and find highest vote
    count = 0;
    const int* result_row = result.ptr<int>(1);
    for( i = 0; i < label_size; i++ )
    {
        val = result_row[i];
        if( max_votes < val )
        {
            max_votes = val;
            predicted_class = i;
        }
        count += val;
    }

    // each tree casts exactly one vote, so the votes must add up to the number of trees,
    // and the class with the most votes must agree with predict()
    EXPECT_EQ(count, (int)rt->getRoots().size());
    EXPECT_EQ(result.at<float>(0, predicted_class), rt->predict(test));
}

TEST(ML_RTrees, 11142_sample_weights_regression)
{
    int n = 3;
    // RTrees for regression
    Ptr<ml::RTrees> rt = cv::ml::RTrees::create();

    // simple regression problem of x -> 2x
    Mat data = (Mat_<float>(n,1) << 1, 2, 3);
    Mat values = (Mat_<float>(n,1) << 2, 4, 6);
    Mat weights = (Mat_<float>(n,1) << 10, 10, 10);

    Ptr<TrainData> trainData = TrainData::create(data, ml::ROW_SAMPLE, values);
    rt->train(trainData);
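    // out-of-bag (OOB) error: prediction error estimated on the samples each tree
    // did not see in its bootstrap sample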
    double error_without_weights = round(rt->getOOBError());
    rt->clear();

    Ptr<TrainData> trainDataWithWeights = TrainData::create(data, ml::ROW_SAMPLE, values, Mat(), Mat(), weights);
    rt->train(trainDataWithWeights);
    double error_with_weights = round(rt->getOOBError());

    // error with weights should be larger than error without weights
    EXPECT_GE(error_with_weights, error_without_weights);
}

TEST(ML_RTrees, 11142_sample_weights_classification)
{
    int n = 12;
    // RTrees for classification
    Ptr<ml::RTrees> rt = cv::ml::RTrees::create();

    Mat data(n, 4, CV_32F);
    randu(data, 0, 10);
    Mat labels = (Mat_<int>(n,1) << 0,0,0,0, 1,1,1,1, 2,2,2,2);
    Mat weights = (Mat_<float>(n,1) << 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10);

    // baseline: train without weights, record the OOB error, then reset the model
    rt->train(data, ml::ROW_SAMPLE, labels);
    double error_without_weights = round(rt->getOOBError());
    rt->clear();

    // retrain with per-sample weights attached to the training data
    Ptr<TrainData> trainDataWithWeights = TrainData::create(data, ml::ROW_SAMPLE, labels, Mat(), Mat(), weights);
    rt->train(trainDataWithWeights);
    double error_with_weights = round(rt->getOOBError());

    // error with weights should be larger than error without weights
    EXPECT_GE(error_with_weights, error_without_weights);
}

TEST(ML_RTrees, bug_12974_throw_exception_when_predict_different_feature_count)
{
    int numFeatures = 5;
    // create a 5 feature dataset and train the model
    cv::Ptr<RTrees> model = RTrees::create();

    Mat samples(10, numFeatures, CV_32F);
    randu(samples, 0, 10);
    Mat labels = (Mat_<int>(10,1) << 0,0,0,0,0, 1,1,1,1,1);

    cv::Ptr<TrainData> trainData = TrainData::create(samples, cv::ml::ROW_SAMPLE, labels);
    model->train(trainData);

    // try to predict on data which have fewer features - this should throw an exception
    for( int i = 1; i < numFeatures - 1; ++i )
    {
        Mat test(1, i, CV_32FC1);
        ASSERT_THROW(model->predict(test), Exception);
    }

    // try to predict on data which have more features - this should also throw an exception
    Mat test(1, numFeatures + 1, CV_32FC1);
    ASSERT_THROW(model->predict(test), Exception);
}

}} // namespace