MLTest.java 1.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. package org.opencv.test.ml;
  2. import org.opencv.ml.Ml;
  3. import org.opencv.ml.SVM;
  4. import org.opencv.core.Mat;
  5. import org.opencv.core.MatOfFloat;
  6. import org.opencv.core.MatOfInt;
  7. import org.opencv.core.CvType;
  8. import org.opencv.test.OpenCVTestCase;
  9. import org.opencv.test.OpenCVTestRunner;
  10. public class MLTest extends OpenCVTestCase {
  11. public void testSaveLoad() {
  12. Mat samples = new MatOfFloat(new float[] {
  13. 5.1f, 3.5f, 1.4f, 0.2f,
  14. 4.9f, 3.0f, 1.4f, 0.2f,
  15. 4.7f, 3.2f, 1.3f, 0.2f,
  16. 4.6f, 3.1f, 1.5f, 0.2f,
  17. 5.0f, 3.6f, 1.4f, 0.2f,
  18. 7.0f, 3.2f, 4.7f, 1.4f,
  19. 6.4f, 3.2f, 4.5f, 1.5f,
  20. 6.9f, 3.1f, 4.9f, 1.5f,
  21. 5.5f, 2.3f, 4.0f, 1.3f,
  22. 6.5f, 2.8f, 4.6f, 1.5f
  23. }).reshape(1, 10);
  24. Mat responses = new MatOfInt(new int[] {
  25. 0, 0, 0, 0, 0, 1, 1, 1, 1, 1
  26. }).reshape(1, 10);
  27. SVM saved = SVM.create();
  28. assertFalse(saved.isTrained());
  29. saved.train(samples, Ml.ROW_SAMPLE, responses);
  30. assertTrue(saved.isTrained());
  31. String filename = OpenCVTestRunner.getTempFileName("yml");
  32. saved.save(filename);
  33. SVM loaded = SVM.load(filename);
  34. assertTrue(loaded.isTrained());
  35. }
  36. }