test_lr.cpp 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. // This file is part of OpenCV project.
  2. // It is subject to the license terms in the LICENSE file found in the top-level directory
  3. // of this distribution and at http://opencv.org/license.html.
  4. //
  5. // AUTHOR: Rahul Kavi rahulkavi[at]live[dot]com
  6. //
  7. // Test data uses subset of data from the popular Iris Dataset (1936):
  8. // - http://archive.ics.uci.edu/ml/datasets/Iris
  9. // - https://en.wikipedia.org/wiki/Iris_flower_data_set
  10. //
  11. #include "test_precomp.hpp"
  12. namespace opencv_test { namespace {
  13. TEST(ML_LR, accuracy)
  14. {
  15. std::string dataFileName = findDataFile("iris.data");
  16. Ptr<TrainData> tdata = TrainData::loadFromCSV(dataFileName, 0);
  17. ASSERT_FALSE(tdata.empty());
  18. Ptr<LogisticRegression> p = LogisticRegression::create();
  19. p->setLearningRate(1.0);
  20. p->setIterations(10001);
  21. p->setRegularization(LogisticRegression::REG_L2);
  22. p->setTrainMethod(LogisticRegression::BATCH);
  23. p->setMiniBatchSize(10);
  24. p->train(tdata);
  25. Mat responses;
  26. p->predict(tdata->getSamples(), responses);
  27. float error = 1000;
  28. EXPECT_TRUE(calculateError(responses, tdata->getResponses(), error));
  29. EXPECT_LE(error, 0.05f);
  30. }
  31. //==================================================================================================
  32. TEST(ML_LR, save_load)
  33. {
  34. string dataFileName = findDataFile("iris.data");
  35. Ptr<TrainData> tdata = TrainData::loadFromCSV(dataFileName, 0);
  36. ASSERT_FALSE(tdata.empty());
  37. Mat responses1, responses2;
  38. Mat learnt_mat1, learnt_mat2;
  39. String filename = tempfile(".xml");
  40. {
  41. Ptr<LogisticRegression> lr1 = LogisticRegression::create();
  42. lr1->setLearningRate(1.0);
  43. lr1->setIterations(10001);
  44. lr1->setRegularization(LogisticRegression::REG_L2);
  45. lr1->setTrainMethod(LogisticRegression::BATCH);
  46. lr1->setMiniBatchSize(10);
  47. ASSERT_NO_THROW(lr1->train(tdata));
  48. ASSERT_NO_THROW(lr1->predict(tdata->getSamples(), responses1));
  49. ASSERT_NO_THROW(lr1->save(filename));
  50. learnt_mat1 = lr1->get_learnt_thetas();
  51. }
  52. {
  53. Ptr<LogisticRegression> lr2;
  54. ASSERT_NO_THROW(lr2 = Algorithm::load<LogisticRegression>(filename));
  55. ASSERT_NO_THROW(lr2->predict(tdata->getSamples(), responses2));
  56. learnt_mat2 = lr2->get_learnt_thetas();
  57. }
  58. // compare difference in prediction outputs and stored inputs
  59. EXPECT_MAT_NEAR(responses1, responses2, 0.f);
  60. Mat comp_learnt_mats;
  61. comp_learnt_mats = (learnt_mat1 == learnt_mat2);
  62. comp_learnt_mats = comp_learnt_mats.reshape(1, comp_learnt_mats.rows*comp_learnt_mats.cols);
  63. comp_learnt_mats.convertTo(comp_learnt_mats, CV_32S);
  64. comp_learnt_mats = comp_learnt_mats/255;
  65. // check if there is any difference between computed learnt mat and retrieved mat
  66. EXPECT_EQ(comp_learnt_mats.rows, sum(comp_learnt_mats)[0]);
  67. remove( filename.c_str() );
  68. }
  69. }} // namespace