test_ann.cpp 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. // This file is part of OpenCV project.
  2. // It is subject to the license terms in the LICENSE file found in the top-level directory
  3. // of this distribution and at http://opencv.org/license.html.
  4. #include "test_precomp.hpp"
  5. // #define GENERATE_TESTDATA
  6. namespace opencv_test { namespace {
  7. struct Activation
  8. {
  9. int id;
  10. const char * name;
  11. };
  12. void PrintTo(const Activation &a, std::ostream *os) { *os << a.name; }
  13. Activation activation_list[] =
  14. {
  15. { ml::ANN_MLP::IDENTITY, "identity" },
  16. { ml::ANN_MLP::SIGMOID_SYM, "sigmoid_sym" },
  17. { ml::ANN_MLP::GAUSSIAN, "gaussian" },
  18. { ml::ANN_MLP::RELU, "relu" },
  19. { ml::ANN_MLP::LEAKYRELU, "leakyrelu" },
  20. };
  21. typedef testing::TestWithParam< Activation > ML_ANN_Params;
  22. TEST_P(ML_ANN_Params, ActivationFunction)
  23. {
  24. const Activation &activation = GetParam();
  25. const string dataname = "waveform";
  26. const string data_path = findDataFile(dataname + ".data");
  27. const string model_name = dataname + "_" + activation.name + ".yml";
  28. Ptr<TrainData> tdata = TrainData::loadFromCSV(data_path, 0);
  29. ASSERT_FALSE(tdata.empty());
  30. // hack?
  31. const uint64 old_state = theRNG().state;
  32. theRNG().state = 1027401484159173092;
  33. tdata->setTrainTestSplit(500);
  34. theRNG().state = old_state;
  35. Mat_<int> layerSizes(1, 4);
  36. layerSizes(0, 0) = tdata->getNVars();
  37. layerSizes(0, 1) = 100;
  38. layerSizes(0, 2) = 100;
  39. layerSizes(0, 3) = tdata->getResponses().cols;
  40. Mat testSamples = tdata->getTestSamples();
  41. Mat rx, ry;
  42. {
  43. Ptr<ml::ANN_MLP> x = ml::ANN_MLP::create();
  44. x->setActivationFunction(activation.id);
  45. x->setLayerSizes(layerSizes);
  46. x->setTrainMethod(ml::ANN_MLP::RPROP, 0.01, 0.1);
  47. x->setTermCriteria(TermCriteria(TermCriteria::COUNT, 300, 0.01));
  48. x->train(tdata, ml::ANN_MLP::NO_OUTPUT_SCALE);
  49. ASSERT_TRUE(x->isTrained());
  50. x->predict(testSamples, rx);
  51. #ifdef GENERATE_TESTDATA
  52. x->save(cvtest::TS::ptr()->get_data_path() + model_name);
  53. #endif
  54. }
  55. {
  56. const string model_path = findDataFile(model_name);
  57. Ptr<ml::ANN_MLP> y = Algorithm::load<ANN_MLP>(model_path);
  58. ASSERT_TRUE(y);
  59. y->predict(testSamples, ry);
  60. EXPECT_MAT_NEAR(rx, ry, FLT_EPSILON);
  61. }
  62. }
  63. INSTANTIATE_TEST_CASE_P(/**/, ML_ANN_Params, testing::ValuesIn(activation_list));
  64. //==================================================================================================
  65. CV_ENUM(ANN_MLP_METHOD, ANN_MLP::RPROP, ANN_MLP::ANNEAL)
  66. typedef tuple<ANN_MLP_METHOD, string, int> ML_ANN_METHOD_Params;
  67. typedef TestWithParam<ML_ANN_METHOD_Params> ML_ANN_METHOD;
  68. TEST_P(ML_ANN_METHOD, Test)
  69. {
  70. int methodType = get<0>(GetParam());
  71. string methodName = get<1>(GetParam());
  72. int N = get<2>(GetParam());
  73. String folder = string(cvtest::TS::ptr()->get_data_path());
  74. String original_path = findDataFile("waveform.data");
  75. string dataname = "waveform_" + methodName;
  76. string weight_name = dataname + "_init_weight.yml.gz";
  77. string model_name = dataname + ".yml.gz";
  78. string response_name = dataname + "_response.yml.gz";
  79. Ptr<TrainData> tdata2 = TrainData::loadFromCSV(original_path, 0);
  80. ASSERT_FALSE(tdata2.empty());
  81. Mat samples = tdata2->getSamples()(Range(0, N), Range::all());
  82. Mat responses(N, 3, CV_32FC1, Scalar(0));
  83. for (int i = 0; i < N; i++)
  84. responses.at<float>(i, static_cast<int>(tdata2->getResponses().at<float>(i, 0))) = 1;
  85. Ptr<TrainData> tdata = TrainData::create(samples, ml::ROW_SAMPLE, responses);
  86. ASSERT_FALSE(tdata.empty());
  87. // hack?
  88. const uint64 old_state = theRNG().state;
  89. theRNG().state = 0;
  90. tdata->setTrainTestSplitRatio(0.8);
  91. theRNG().state = old_state;
  92. Mat testSamples = tdata->getTestSamples();
  93. // train 1st stage
  94. Ptr<ml::ANN_MLP> xx = ml::ANN_MLP::create();
  95. Mat_<int> layerSizes(1, 4);
  96. layerSizes(0, 0) = tdata->getNVars();
  97. layerSizes(0, 1) = 30;
  98. layerSizes(0, 2) = 30;
  99. layerSizes(0, 3) = tdata->getResponses().cols;
  100. xx->setLayerSizes(layerSizes);
  101. xx->setActivationFunction(ml::ANN_MLP::SIGMOID_SYM);
  102. xx->setTrainMethod(ml::ANN_MLP::RPROP);
  103. xx->setTermCriteria(TermCriteria(TermCriteria::COUNT, 1, 0.01));
  104. xx->train(tdata, ml::ANN_MLP::NO_OUTPUT_SCALE + ml::ANN_MLP::NO_INPUT_SCALE);
  105. #ifdef GENERATE_TESTDATA
  106. {
  107. FileStorage fs;
  108. fs.open(cvtest::TS::ptr()->get_data_path() + weight_name, FileStorage::WRITE + FileStorage::BASE64);
  109. xx->write(fs);
  110. }
  111. #endif
  112. // train 2nd stage
  113. Mat r_gold;
  114. Ptr<ml::ANN_MLP> x = ml::ANN_MLP::create();
  115. {
  116. const string weight_file = findDataFile(weight_name);
  117. FileStorage fs;
  118. fs.open(weight_file, FileStorage::READ);
  119. x->read(fs.root());
  120. }
  121. x->setTrainMethod(methodType);
  122. if (methodType == ml::ANN_MLP::ANNEAL)
  123. {
  124. x->setAnnealEnergyRNG(RNG(CV_BIG_INT(0xffffffff)));
  125. x->setAnnealInitialT(12);
  126. x->setAnnealFinalT(0.15);
  127. x->setAnnealCoolingRatio(0.96);
  128. x->setAnnealItePerStep(11);
  129. }
  130. x->setTermCriteria(TermCriteria(TermCriteria::COUNT, 100, 0.01));
  131. x->train(tdata, ml::ANN_MLP::NO_OUTPUT_SCALE + ml::ANN_MLP::NO_INPUT_SCALE + ml::ANN_MLP::UPDATE_WEIGHTS);
  132. ASSERT_TRUE(x->isTrained());
  133. #ifdef GENERATE_TESTDATA
  134. x->save(cvtest::TS::ptr()->get_data_path() + model_name);
  135. x->predict(testSamples, r_gold);
  136. {
  137. FileStorage fs_response(cvtest::TS::ptr()->get_data_path() + response_name, FileStorage::WRITE + FileStorage::BASE64);
  138. fs_response << "response" << r_gold;
  139. }
  140. #endif
  141. {
  142. const string response_file = findDataFile(response_name);
  143. FileStorage fs_response(response_file, FileStorage::READ);
  144. fs_response["response"] >> r_gold;
  145. }
  146. ASSERT_FALSE(r_gold.empty());
  147. // verify
  148. const string model_file = findDataFile(model_name);
  149. Ptr<ml::ANN_MLP> y = Algorithm::load<ANN_MLP>(model_file);
  150. ASSERT_TRUE(y);
  151. Mat rx, ry;
  152. for (int j = 0; j < 4; j++)
  153. {
  154. rx = x->getWeights(j);
  155. ry = y->getWeights(j);
  156. EXPECT_MAT_NEAR(rx, ry, FLT_EPSILON) << "Weights are not equal for layer: " << j;
  157. }
  158. x->predict(testSamples, rx);
  159. y->predict(testSamples, ry);
  160. EXPECT_MAT_NEAR(ry, rx, FLT_EPSILON) << "Predict are not equal to result of the saved model";
  161. EXPECT_MAT_NEAR(r_gold, rx, FLT_EPSILON) << "Predict are not equal to 'gold' response";
  162. }
  163. INSTANTIATE_TEST_CASE_P(/*none*/, ML_ANN_METHOD,
  164. testing::Values(
  165. ML_ANN_METHOD_Params(ml::ANN_MLP::RPROP, "rprop", 5000),
  166. ML_ANN_METHOD_Params(ml::ANN_MLP::ANNEAL, "anneal", 1000)
  167. // ML_ANN_METHOD_Params(ml::ANN_MLP::BACKPROP, "backprop", 500) -----> NO BACKPROP TEST
  168. )
  169. );
  170. }} // namespace