# test_registration.py (3.9 KB)

  1. import unittest
  2. import os
  3. import sys
  4. import numpy as np
  5. import cv2 as cv
  6. sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__),
  7. '..', '..')))
  8. from opencv_stitching.feature_detector import FeatureDetector
  9. from opencv_stitching.feature_matcher import FeatureMatcher
  10. from opencv_stitching.subsetter import Subsetter
  11. class TestImageRegistration(unittest.TestCase):
  12. def test_feature_detector(self):
  13. img1 = cv.imread("s1.jpg")
  14. default_number_of_keypoints = 500
  15. detector = FeatureDetector("orb")
  16. features = detector.detect_features(img1)
  17. self.assertEqual(len(features.getKeypoints()),
  18. default_number_of_keypoints)
  19. other_keypoints = 1000
  20. detector = FeatureDetector("orb", nfeatures=other_keypoints)
  21. features = detector.detect_features(img1)
  22. self.assertEqual(len(features.getKeypoints()), other_keypoints)
  23. def test_feature_matcher(self):
  24. img1, img2 = cv.imread("s1.jpg"), cv.imread("s2.jpg")
  25. detector = FeatureDetector("orb")
  26. features = [detector.detect_features(img1),
  27. detector.detect_features(img2)]
  28. matcher = FeatureMatcher()
  29. pairwise_matches = matcher.match_features(features)
  30. self.assertEqual(len(pairwise_matches), len(features)**2)
  31. self.assertGreater(pairwise_matches[1].confidence, 2)
  32. matches_matrix = FeatureMatcher.get_matches_matrix(pairwise_matches)
  33. self.assertEqual(matches_matrix.shape, (2, 2))
  34. conf_matrix = FeatureMatcher.get_confidence_matrix(pairwise_matches)
  35. self.assertTrue(np.array_equal(
  36. conf_matrix > 2,
  37. np.array([[False, True], [True, False]])
  38. ))
  39. def test_subsetting(self):
  40. img1, img2 = cv.imread("s1.jpg"), cv.imread("s2.jpg")
  41. img3, img4 = cv.imread("boat1.jpg"), cv.imread("boat2.jpg")
  42. img5 = cv.imread("boat3.jpg")
  43. img_names = ["s1.jpg", "s2.jpg", "boat1.jpg", "boat2.jpg", "boat3.jpg"]
  44. detector = FeatureDetector("orb")
  45. features = [detector.detect_features(img1),
  46. detector.detect_features(img2),
  47. detector.detect_features(img3),
  48. detector.detect_features(img4),
  49. detector.detect_features(img5)]
  50. matcher = FeatureMatcher()
  51. pairwise_matches = matcher.match_features(features)
  52. subsetter = Subsetter(confidence_threshold=1,
  53. matches_graph_dot_file="dot_graph.txt") # view in https://dreampuf.github.io # noqa
  54. indices = subsetter.get_indices_to_keep(features, pairwise_matches)
  55. indices_to_delete = subsetter.get_indices_to_delete(len(img_names),
  56. indices)
  57. np.testing.assert_array_equal(indices, np.array([2, 3, 4]))
  58. np.testing.assert_array_equal(indices_to_delete, np.array([0, 1]))
  59. subsetted_image_names = subsetter.subset_list(img_names, indices)
  60. self.assertEqual(subsetted_image_names,
  61. ['boat1.jpg', 'boat2.jpg', 'boat3.jpg'])
  62. matches_subset = subsetter.subset_matches(pairwise_matches, indices)
  63. # FeatureMatcher.get_confidence_matrix(pairwise_matches)
  64. # FeatureMatcher.get_confidence_matrix(subsetted_matches)
  65. self.assertEqual(pairwise_matches[13].confidence,
  66. matches_subset[1].confidence)
  67. graph = subsetter.get_matches_graph(img_names, pairwise_matches)
  68. self.assertTrue(graph.startswith("graph matches_graph{"))
  69. subsetter.save_matches_graph_dot_file(img_names, pairwise_matches)
  70. with open('dot_graph.txt', 'r') as file:
  71. graph = file.read()
  72. self.assertTrue(graph.startswith("graph matches_graph{"))
  73. def starttest():
  74. unittest.main()
  75. if __name__ == "__main__":
  76. starttest()