// test_bayes.cpp
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
  4. #include "test_precomp.hpp"
  5. namespace opencv_test { namespace {
  6. TEST(ML_NBAYES, regression_5911)
  7. {
  8. int N=12;
  9. Ptr<ml::NormalBayesClassifier> nb = cv::ml::NormalBayesClassifier::create();
  10. // data:
  11. float X_data[] = {
  12. 1,2,3,4, 1,2,3,4, 1,2,3,4, 1,2,3,4,
  13. 5,5,5,5, 5,5,5,5, 5,5,5,5, 5,5,5,5,
  14. 4,3,2,1, 4,3,2,1, 4,3,2,1, 4,3,2,1
  15. };
  16. Mat_<float> X(N, 4, X_data);
  17. // labels:
  18. int Y_data[] = { 0,0,0,0, 1,1,1,1, 2,2,2,2 };
  19. Mat_<int> Y(N, 1, Y_data);
  20. nb->train(X, ml::ROW_SAMPLE, Y);
  21. // single prediction:
  22. Mat R1,P1;
  23. for (int i=0; i<N; i++)
  24. {
  25. Mat r,p;
  26. nb->predictProb(X.row(i), r, p);
  27. R1.push_back(r);
  28. P1.push_back(p);
  29. }
  30. // bulk prediction (continuous memory):
  31. Mat R2,P2;
  32. nb->predictProb(X, R2, P2);
  33. EXPECT_EQ(255 * R2.total(), sum(R1 == R2)[0]);
  34. EXPECT_EQ(255 * P2.total(), sum(P1 == P2)[0]);
  35. // bulk prediction, with non-continuous memory storage
  36. Mat R3_(N, 1+1, CV_32S),
  37. P3_(N, 3+1, CV_32F);
  38. nb->predictProb(X, R3_.col(0), P3_.colRange(0,3));
  39. Mat R3 = R3_.col(0).clone(),
  40. P3 = P3_.colRange(0,3).clone();
  41. EXPECT_EQ(255 * R3.total(), sum(R1 == R3)[0]);
  42. EXPECT_EQ(255 * P3.total(), sum(P1 == P3)[0]);
  43. }
  44. }} // namespace