test_mltests.cpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373
  1. // This file is part of OpenCV project.
  2. // It is subject to the license terms in the LICENSE file found in the top-level directory
  3. // of this distribution and at http://opencv.org/license.html.
  4. #include "test_precomp.hpp"
  5. namespace opencv_test { namespace {
  6. struct DatasetDesc
  7. {
  8. string name;
  9. int resp_idx;
  10. int train_count;
  11. int cat_num;
  12. string type_desc;
  13. public:
  14. Ptr<TrainData> load()
  15. {
  16. string filename = findDataFile(name + ".data");
  17. Ptr<TrainData> data = TrainData::loadFromCSV(filename, 0, resp_idx, resp_idx + 1, type_desc);
  18. data->setTrainTestSplit(train_count);
  19. data->shuffleTrainTest();
  20. return data;
  21. }
  22. };
  23. // see testdata/ml/protocol.txt (?)
  24. DatasetDesc datasets[] = {
  25. { "mushroom", 0, 4000, 16, "cat" },
  26. { "adult", 14, 22561, 16, "ord[0,2,4,10-12],cat[1,3,5-9,13,14]" },
  27. { "vehicle", 18, 761, 4, "ord[0-17],cat[18]" },
  28. { "abalone", 8, 3133, 16, "ord[1-8],cat[0]" },
  29. { "ringnorm", 20, 300, 2, "ord[0-19],cat[20]" },
  30. { "spambase", 57, 3221, 3, "ord[0-56],cat[57]" },
  31. { "waveform", 21, 300, 3, "ord[0-20],cat[21]" },
  32. { "elevators", 18, 5000, 0, "ord" },
  33. { "letter", 16, 10000, 26, "ord[0-15],cat[16]" },
  34. { "twonorm", 20, 300, 3, "ord[0-19],cat[20]" },
  35. { "poletelecomm", 48, 2500, 0, "ord" },
  36. };
  37. static DatasetDesc & getDataset(const string & name)
  38. {
  39. const int sz = sizeof(datasets)/sizeof(datasets[0]);
  40. for (int i = 0; i < sz; ++i)
  41. {
  42. DatasetDesc & desc = datasets[i];
  43. if (desc.name == name)
  44. return desc;
  45. }
  46. CV_Error(Error::StsInternal, "");
  47. }
  48. //==================================================================================================
  49. // interfaces and templates
  50. template <typename T> string modelName() { return "Unknown"; };
  51. template <typename T> Ptr<T> tuneModel(const DatasetDesc &, Ptr<T> m) { return m; }
  52. struct IModelFactory
  53. {
  54. virtual Ptr<StatModel> createNew(const DatasetDesc &dataset) const = 0;
  55. virtual Ptr<StatModel> loadFromFile(const string &filename) const = 0;
  56. virtual string name() const = 0;
  57. virtual ~IModelFactory() {}
  58. };
  59. template <typename T>
  60. struct ModelFactory : public IModelFactory
  61. {
  62. Ptr<StatModel> createNew(const DatasetDesc &dataset) const CV_OVERRIDE
  63. {
  64. return tuneModel<T>(dataset, T::create());
  65. }
  66. Ptr<StatModel> loadFromFile(const string & filename) const CV_OVERRIDE
  67. {
  68. return T::load(filename);
  69. }
  70. string name() const CV_OVERRIDE { return modelName<T>(); }
  71. };
  72. // implementation
  73. template <> string modelName<NormalBayesClassifier>() { return "NormalBayesClassifier"; }
  74. template <> string modelName<DTrees>() { return "DTrees"; }
  75. template <> string modelName<KNearest>() { return "KNearest"; }
  76. template <> string modelName<RTrees>() { return "RTrees"; }
  77. template <> string modelName<SVMSGD>() { return "SVMSGD"; }
  78. template<> Ptr<DTrees> tuneModel<DTrees>(const DatasetDesc &dataset, Ptr<DTrees> m)
  79. {
  80. m->setMaxDepth(10);
  81. m->setMinSampleCount(2);
  82. m->setRegressionAccuracy(0);
  83. m->setUseSurrogates(false);
  84. m->setCVFolds(0);
  85. m->setUse1SERule(false);
  86. m->setTruncatePrunedTree(false);
  87. m->setPriors(Mat());
  88. m->setMaxCategories(dataset.cat_num);
  89. return m;
  90. }
  91. template<> Ptr<RTrees> tuneModel<RTrees>(const DatasetDesc &dataset, Ptr<RTrees> m)
  92. {
  93. m->setMaxDepth(20);
  94. m->setMinSampleCount(2);
  95. m->setRegressionAccuracy(0);
  96. m->setUseSurrogates(false);
  97. m->setPriors(Mat());
  98. m->setCalculateVarImportance(true);
  99. m->setActiveVarCount(0);
  100. m->setTermCriteria(TermCriteria(TermCriteria::COUNT, 100, 0.0));
  101. m->setMaxCategories(dataset.cat_num);
  102. return m;
  103. }
  104. template<> Ptr<SVMSGD> tuneModel<SVMSGD>(const DatasetDesc &, Ptr<SVMSGD> m)
  105. {
  106. m->setSvmsgdType(SVMSGD::ASGD);
  107. m->setMarginType(SVMSGD::SOFT_MARGIN);
  108. m->setMarginRegularization(0.00001f);
  109. m->setInitialStepSize(0.1f);
  110. m->setStepDecreasingPower(0.75);
  111. m->setTermCriteria(TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 10000, 0.00001));
  112. return m;
  113. }
  114. template <>
  115. struct ModelFactory<Boost> : public IModelFactory
  116. {
  117. ModelFactory(int boostType_) : boostType(boostType_) {}
  118. Ptr<StatModel> createNew(const DatasetDesc &) const CV_OVERRIDE
  119. {
  120. Ptr<Boost> m = Boost::create();
  121. m->setBoostType(boostType);
  122. m->setWeakCount(20);
  123. m->setWeightTrimRate(0.95);
  124. m->setMaxDepth(4);
  125. m->setUseSurrogates(false);
  126. m->setPriors(Mat());
  127. return m;
  128. }
  129. Ptr<StatModel> loadFromFile(const string &filename) const { return Boost::load(filename); }
  130. string name() const CV_OVERRIDE { return "Boost"; }
  131. int boostType;
  132. };
  133. template <>
  134. struct ModelFactory<SVM> : public IModelFactory
  135. {
  136. ModelFactory(int svmType_, int kernelType_, double gamma_, double c_, double nu_)
  137. : svmType(svmType_), kernelType(kernelType_), gamma(gamma_), c(c_), nu(nu_) {}
  138. Ptr<StatModel> createNew(const DatasetDesc &) const CV_OVERRIDE
  139. {
  140. Ptr<SVM> m = SVM::create();
  141. m->setType(svmType);
  142. m->setKernel(kernelType);
  143. m->setDegree(0);
  144. m->setGamma(gamma);
  145. m->setCoef0(0);
  146. m->setC(c);
  147. m->setNu(nu);
  148. m->setP(0);
  149. return m;
  150. }
  151. Ptr<StatModel> loadFromFile(const string &filename) const { return SVM::load(filename); }
  152. string name() const CV_OVERRIDE { return "SVM"; }
  153. int svmType;
  154. int kernelType;
  155. double gamma;
  156. double c;
  157. double nu;
  158. };
//==================================================================================================
/** @brief One accuracy test case: model factory, dataset name and reference error statistics.
 *
 * Built with aggregate initialization in ML_Params_List, so the member order is significant.
 */
struct ML_Params_t
{
    Ptr<IModelFactory> factory; //!< creates the model under test
    string dataset;             //!< dataset name, resolved with getDataset()
    float mean;                 //!< expected value of calcError() on the test split
    float sigma;                //!< tolerance unit: the accuracy test accepts |err - mean| <= 4 * sigma
};
  167. void PrintTo(const ML_Params_t & param, std::ostream *os)
  168. {
  169. *os << param.factory->name() << "_" << param.dataset;
  170. }
  171. ML_Params_t ML_Params_List[] = {
  172. { makePtr< ModelFactory<DTrees> >(), "mushroom", 0.027401f, 0.036236f },
  173. { makePtr< ModelFactory<DTrees> >(), "adult", 14.279000f, 0.354323f },
  174. { makePtr< ModelFactory<DTrees> >(), "vehicle", 29.761162f, 4.823927f },
  175. { makePtr< ModelFactory<DTrees> >(), "abalone", 7.297540f, 0.510058f },
  176. { makePtr< ModelFactory<Boost> >(Boost::REAL), "adult", 13.894001f, 0.337763f },
  177. { makePtr< ModelFactory<Boost> >(Boost::DISCRETE), "mushroom", 0.007274f, 0.029400f },
  178. { makePtr< ModelFactory<Boost> >(Boost::LOGIT), "ringnorm", 9.993943f, 0.860256f },
  179. { makePtr< ModelFactory<Boost> >(Boost::GENTLE), "spambase", 5.404347f, 0.581716f },
  180. { makePtr< ModelFactory<RTrees> >(), "waveform", 17.100641f, 0.630052f },
  181. { makePtr< ModelFactory<RTrees> >(), "mushroom", 0.006547f, 0.028248f },
  182. { makePtr< ModelFactory<RTrees> >(), "adult", 13.5129f, 0.266065f },
  183. { makePtr< ModelFactory<RTrees> >(), "abalone", 4.745199f, 0.282112f },
  184. { makePtr< ModelFactory<RTrees> >(), "vehicle", 24.964712f, 4.469287f },
  185. { makePtr< ModelFactory<RTrees> >(), "letter", 5.334999f, 0.261142f },
  186. { makePtr< ModelFactory<RTrees> >(), "ringnorm", 6.248733f, 0.904713f },
  187. { makePtr< ModelFactory<RTrees> >(), "twonorm", 4.506479f, 0.449739f },
  188. { makePtr< ModelFactory<RTrees> >(), "spambase", 5.243477f, 0.54232f },
  189. };
  190. typedef testing::TestWithParam<ML_Params_t> ML_Params;
  191. TEST_P(ML_Params, accuracy)
  192. {
  193. const ML_Params_t & param = GetParam();
  194. DatasetDesc &dataset = getDataset(param.dataset);
  195. Ptr<TrainData> data = dataset.load();
  196. ASSERT_TRUE(data);
  197. ASSERT_TRUE(data->getNSamples() > 0);
  198. Ptr<StatModel> m = param.factory->createNew(dataset);
  199. ASSERT_TRUE(m);
  200. ASSERT_TRUE(m->train(data, 0));
  201. float err = m->calcError(data, true, noArray());
  202. EXPECT_NEAR(err, param.mean, 4 * param.sigma);
  203. }
  204. INSTANTIATE_TEST_CASE_P(/**/, ML_Params, testing::ValuesIn(ML_Params_List));
//==================================================================================================
/** @brief One save/load round-trip test case: the model factory and the dataset to train on.
 *
 * Built with aggregate initialization in ML_SL_Params_List, so the member order is significant.
 */
struct ML_SL_Params_t
{
    Ptr<IModelFactory> factory; //!< creates the model under test and reloads it from file
    string dataset;             //!< dataset name, resolved with getDataset()
};
  211. void PrintTo(const ML_SL_Params_t & param, std::ostream *os)
  212. {
  213. *os << param.factory->name() << "_" << param.dataset;
  214. }
  215. ML_SL_Params_t ML_SL_Params_List[] = {
  216. { makePtr< ModelFactory<NormalBayesClassifier> >(), "waveform" },
  217. { makePtr< ModelFactory<KNearest> >(), "waveform" },
  218. { makePtr< ModelFactory<KNearest> >(), "abalone" },
  219. { makePtr< ModelFactory<SVM> >(SVM::C_SVC, SVM::LINEAR, 1, 0.5, 0), "waveform" },
  220. { makePtr< ModelFactory<SVM> >(SVM::NU_SVR, SVM::RBF, 0.00225, 62.5, 0.03), "poletelecomm" },
  221. { makePtr< ModelFactory<DTrees> >(), "mushroom" },
  222. { makePtr< ModelFactory<DTrees> >(), "abalone" },
  223. { makePtr< ModelFactory<Boost> >(Boost::REAL), "adult" },
  224. { makePtr< ModelFactory<RTrees> >(), "waveform" },
  225. { makePtr< ModelFactory<RTrees> >(), "abalone" },
  226. { makePtr< ModelFactory<SVMSGD> >(), "waveform" },
  227. };
  228. typedef testing::TestWithParam<ML_SL_Params_t> ML_SL_Params;
  229. TEST_P(ML_SL_Params, save_load)
  230. {
  231. const ML_SL_Params_t & param = GetParam();
  232. DatasetDesc &dataset = getDataset(param.dataset);
  233. Ptr<TrainData> data = dataset.load();
  234. ASSERT_TRUE(data);
  235. ASSERT_TRUE(data->getNSamples() > 0);
  236. Mat responses1, responses2;
  237. string file1 = tempfile(".json.gz");
  238. string file2 = tempfile(".json.gz");
  239. {
  240. Ptr<StatModel> m = param.factory->createNew(dataset);
  241. ASSERT_TRUE(m);
  242. ASSERT_TRUE(m->train(data, 0));
  243. m->calcError(data, true, responses1);
  244. m->save(file1 + "?base64");
  245. }
  246. {
  247. Ptr<StatModel> m = param.factory->loadFromFile(file1);
  248. ASSERT_TRUE(m);
  249. m->calcError(data, true, responses2);
  250. m->save(file2 + "?base64");
  251. }
  252. EXPECT_MAT_NEAR(responses1, responses2, 0.0);
  253. {
  254. ifstream f1(file1.c_str(), std::ios_base::binary);
  255. ifstream f2(file2.c_str(), std::ios_base::binary);
  256. ASSERT_TRUE(f1.is_open() && f2.is_open());
  257. const size_t BUFSZ = 10000;
  258. vector<char> buf1(BUFSZ, 0);
  259. vector<char> buf2(BUFSZ, 0);
  260. while (true)
  261. {
  262. f1.read(&buf1[0], BUFSZ);
  263. f2.read(&buf2[0], BUFSZ);
  264. EXPECT_EQ(f1.gcount(), f2.gcount());
  265. EXPECT_EQ(f1.eof(), f2.eof());
  266. if (!f1.good() || !f2.good() || f1.gcount() != f2.gcount())
  267. break;
  268. ASSERT_EQ(buf1, buf2);
  269. }
  270. }
  271. remove(file1.c_str());
  272. remove(file2.c_str());
  273. }
  274. INSTANTIATE_TEST_CASE_P(/**/, ML_SL_Params, testing::ValuesIn(ML_SL_Params_List));
  275. //==================================================================================================
  276. TEST(TrainDataGet, layout_ROW_SAMPLE) // Details: #12236
  277. {
  278. cv::Mat test = cv::Mat::ones(150, 30, CV_32FC1) * 2;
  279. test.col(3) += Scalar::all(3);
  280. cv::Mat labels = cv::Mat::ones(150, 3, CV_32SC1) * 5;
  281. labels.col(1) += 1;
  282. cv::Ptr<cv::ml::TrainData> train_data = cv::ml::TrainData::create(test, cv::ml::ROW_SAMPLE, labels);
  283. train_data->setTrainTestSplitRatio(0.9);
  284. Mat tidx = train_data->getTestSampleIdx();
  285. EXPECT_EQ((size_t)15, tidx.total());
  286. Mat tresp = train_data->getTestResponses();
  287. EXPECT_EQ(15, tresp.rows);
  288. EXPECT_EQ(labels.cols, tresp.cols);
  289. EXPECT_EQ(5, tresp.at<int>(0, 0)) << tresp;
  290. EXPECT_EQ(6, tresp.at<int>(0, 1)) << tresp;
  291. EXPECT_EQ(6, tresp.at<int>(14, 1)) << tresp;
  292. EXPECT_EQ(5, tresp.at<int>(14, 2)) << tresp;
  293. Mat tsamples = train_data->getTestSamples();
  294. EXPECT_EQ(15, tsamples.rows);
  295. EXPECT_EQ(test.cols, tsamples.cols);
  296. EXPECT_EQ(2, tsamples.at<float>(0, 0)) << tsamples;
  297. EXPECT_EQ(5, tsamples.at<float>(0, 3)) << tsamples;
  298. EXPECT_EQ(2, tsamples.at<float>(14, test.cols - 1)) << tsamples;
  299. EXPECT_EQ(5, tsamples.at<float>(14, 3)) << tsamples;
  300. }
  301. TEST(TrainDataGet, layout_COL_SAMPLE) // Details: #12236
  302. {
  303. cv::Mat test = cv::Mat::ones(30, 150, CV_32FC1) * 3;
  304. test.row(3) += Scalar::all(3);
  305. cv::Mat labels = cv::Mat::ones(3, 150, CV_32SC1) * 5;
  306. labels.row(1) += 1;
  307. cv::Ptr<cv::ml::TrainData> train_data = cv::ml::TrainData::create(test, cv::ml::COL_SAMPLE, labels);
  308. train_data->setTrainTestSplitRatio(0.9);
  309. Mat tidx = train_data->getTestSampleIdx();
  310. EXPECT_EQ((size_t)15, tidx.total());
  311. Mat tresp = train_data->getTestResponses(); // always row-based, transposed
  312. EXPECT_EQ(15, tresp.rows);
  313. EXPECT_EQ(labels.rows, tresp.cols);
  314. EXPECT_EQ(5, tresp.at<int>(0, 0)) << tresp;
  315. EXPECT_EQ(6, tresp.at<int>(0, 1)) << tresp;
  316. EXPECT_EQ(6, tresp.at<int>(14, 1)) << tresp;
  317. EXPECT_EQ(5, tresp.at<int>(14, 2)) << tresp;
  318. Mat tsamples = train_data->getTestSamples();
  319. EXPECT_EQ(15, tsamples.cols);
  320. EXPECT_EQ(test.rows, tsamples.rows);
  321. EXPECT_EQ(3, tsamples.at<float>(0, 0)) << tsamples;
  322. EXPECT_EQ(6, tsamples.at<float>(3, 0)) << tsamples;
  323. EXPECT_EQ(6, tsamples.at<float>(3, 14)) << tsamples;
  324. EXPECT_EQ(3, tsamples.at<float>(test.rows - 1, 14)) << tsamples;
  325. }
  326. }} // namespace