test_save_load.cpp 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. // This file is part of OpenCV project.
  2. // It is subject to the license terms in the LICENSE file found in the top-level directory
  3. // of this distribution and at http://opencv.org/license.html.
  4. #include "test_precomp.hpp"
  5. namespace opencv_test { namespace {
  6. void randomFillCategories(const string & filename, Mat & input)
  7. {
  8. Mat catMap;
  9. Mat catCount;
  10. std::vector<uchar> varTypes;
  11. FileStorage fs(filename, FileStorage::READ);
  12. FileNode root = fs.getFirstTopLevelNode();
  13. root["cat_map"] >> catMap;
  14. root["cat_count"] >> catCount;
  15. root["var_type"] >> varTypes;
  16. int offset = 0;
  17. int countOffset = 0;
  18. uint var = 0, varCount = (uint)varTypes.size();
  19. for (; var < varCount; ++var)
  20. {
  21. if (varTypes[var] == ml::VAR_CATEGORICAL)
  22. {
  23. int size = catCount.at<int>(0, countOffset);
  24. for (int row = 0; row < input.rows; ++row)
  25. {
  26. int randomChosenIndex = offset + ((uint)cv::theRNG()) % size;
  27. int value = catMap.at<int>(0, randomChosenIndex);
  28. input.at<float>(row, var) = (float)value;
  29. }
  30. offset += size;
  31. ++countOffset;
  32. }
  33. }
  34. }
  35. //==================================================================================================
  36. typedef tuple<string, string> ML_Legacy_Param;
  37. typedef testing::TestWithParam< ML_Legacy_Param > ML_Legacy_Params;
  38. TEST_P(ML_Legacy_Params, legacy_load)
  39. {
  40. const string modelName = get<0>(GetParam());
  41. const string dataName = get<1>(GetParam());
  42. const string filename = findDataFile("legacy/" + modelName + "_" + dataName + ".xml");
  43. const bool isTree = modelName == CV_BOOST || modelName == CV_DTREE || modelName == CV_RTREES;
  44. Ptr<StatModel> model;
  45. if (modelName == CV_BOOST)
  46. model = Algorithm::load<Boost>(filename);
  47. else if (modelName == CV_ANN)
  48. model = Algorithm::load<ANN_MLP>(filename);
  49. else if (modelName == CV_DTREE)
  50. model = Algorithm::load<DTrees>(filename);
  51. else if (modelName == CV_NBAYES)
  52. model = Algorithm::load<NormalBayesClassifier>(filename);
  53. else if (modelName == CV_SVM)
  54. model = Algorithm::load<SVM>(filename);
  55. else if (modelName == CV_RTREES)
  56. model = Algorithm::load<RTrees>(filename);
  57. else if (modelName == CV_SVMSGD)
  58. model = Algorithm::load<SVMSGD>(filename);
  59. ASSERT_TRUE(model);
  60. Mat input = Mat(isTree ? 10 : 1, model->getVarCount(), CV_32F);
  61. cv::theRNG().fill(input, RNG::UNIFORM, 0, 40);
  62. if (isTree)
  63. randomFillCategories(filename, input);
  64. Mat output;
  65. EXPECT_NO_THROW(model->predict(input, output, StatModel::RAW_OUTPUT | (isTree ? DTrees::PREDICT_SUM : 0)));
  66. // just check if no internal assertions or errors thrown
  67. }
  68. ML_Legacy_Param param_list[] = {
  69. ML_Legacy_Param(CV_ANN, "waveform"),
  70. ML_Legacy_Param(CV_BOOST, "adult"),
  71. ML_Legacy_Param(CV_BOOST, "1"),
  72. ML_Legacy_Param(CV_BOOST, "2"),
  73. ML_Legacy_Param(CV_BOOST, "3"),
  74. ML_Legacy_Param(CV_DTREE, "abalone"),
  75. ML_Legacy_Param(CV_DTREE, "mushroom"),
  76. ML_Legacy_Param(CV_NBAYES, "waveform"),
  77. ML_Legacy_Param(CV_SVM, "poletelecomm"),
  78. ML_Legacy_Param(CV_SVM, "waveform"),
  79. ML_Legacy_Param(CV_RTREES, "waveform"),
  80. ML_Legacy_Param(CV_SVMSGD, "waveform"),
  81. };
  82. INSTANTIATE_TEST_CASE_P(/**/, ML_Legacy_Params, testing::ValuesIn(param_list));
  83. /*TEST(ML_SVM, throw_exception_when_save_untrained_model)
  84. {
  85. Ptr<cv::ml::SVM> svm;
  86. string filename = tempfile("svm.xml");
  87. ASSERT_THROW(svm.save(filename.c_str()), Exception);
  88. remove(filename.c_str());
  89. }*/
  90. }} // namespace