test_feature_homography.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. #!/usr/bin/env python
'''
Feature homography
==================
Test of the features2d framework for homography matching on synthetic video
frames rendered by TestSceneRender. AKAZE features and a FLANN (LSH) matcher
are used. The actual tracking is implemented by the PlaneTracker class defined
in this file (originally in plane_tracker.py).
'''
  9. # Python 2/3 compatibility
  10. from __future__ import print_function
  11. import numpy as np
  12. import cv2 as cv
  13. import sys
  14. PY3 = sys.version_info[0] == 3
  15. if PY3:
  16. xrange = range
  17. # local modules
  18. from tst_scene_render import TestSceneRender
  19. def intersectionRate(s1, s2):
  20. x1, y1, x2, y2 = s1
  21. s1 = np.array([[x1, y1], [x2,y1], [x2, y2], [x1, y2]])
  22. area, _intersection = cv.intersectConvexConvex(s1, np.array(s2))
  23. return 2 * area / (cv.contourArea(s1) + cv.contourArea(np.array(s2)))
  24. from tests_common import NewOpenCVTests
  25. class feature_homography_test(NewOpenCVTests):
  26. render = None
  27. tracker = None
  28. framesCounter = 0
  29. frame = None
  30. def test_feature_homography(self):
  31. self.render = TestSceneRender(self.get_sample('samples/data/graf1.png'),
  32. self.get_sample('samples/data/box.png'), noise = 0.5, speed = 0.5)
  33. self.frame = self.render.getNextFrame()
  34. self.tracker = PlaneTracker()
  35. self.tracker.clear()
  36. self.tracker.add_target(self.frame, self.render.getCurrentRect())
  37. while self.framesCounter < 100:
  38. self.framesCounter += 1
  39. tracked = self.tracker.track(self.frame)
  40. if len(tracked) > 0:
  41. tracked = tracked[0]
  42. self.assertGreater(intersectionRate(self.render.getCurrentRect(), np.int32(tracked.quad)), 0.6)
  43. else:
  44. self.assertEqual(0, 1, 'Tracking error')
  45. self.frame = self.render.getNextFrame()
  46. # built-in modules
  47. from collections import namedtuple
  48. FLANN_INDEX_KDTREE = 1
  49. FLANN_INDEX_LSH = 6
  50. flann_params= dict(algorithm = FLANN_INDEX_LSH,
  51. table_number = 6, # 12
  52. key_size = 12, # 20
  53. multi_probe_level = 1) #2
  54. MIN_MATCH_COUNT = 10
  55. '''
  56. image - image to track
  57. rect - tracked rectangle (x1, y1, x2, y2)
  58. keypoints - keypoints detected inside rect
  59. descrs - their descriptors
  60. data - some user-provided data
  61. '''
  62. PlanarTarget = namedtuple('PlaneTarget', 'image, rect, keypoints, descrs, data')
  63. '''
  64. target - reference to PlanarTarget
  65. p0 - matched points coords in target image
  66. p1 - matched points coords in input frame
  67. H - homography matrix from p0 to p1
  68. quad - target boundary quad in input frame
  69. '''
  70. TrackedTarget = namedtuple('TrackedTarget', 'target, p0, p1, H, quad')
  71. class PlaneTracker:
  72. def __init__(self):
  73. self.detector = cv.AKAZE_create(threshold = 0.003)
  74. self.matcher = cv.FlannBasedMatcher(flann_params, {}) # bug : need to pass empty dict (#1329)
  75. self.targets = []
  76. self.frame_points = []
  77. def add_target(self, image, rect, data=None):
  78. '''Add a new tracking target.'''
  79. x0, y0, x1, y1 = rect
  80. raw_points, raw_descrs = self.detect_features(image)
  81. points, descs = [], []
  82. for kp, desc in zip(raw_points, raw_descrs):
  83. x, y = kp.pt
  84. if x0 <= x <= x1 and y0 <= y <= y1:
  85. points.append(kp)
  86. descs.append(desc)
  87. descs = np.uint8(descs)
  88. self.matcher.add([descs])
  89. target = PlanarTarget(image = image, rect=rect, keypoints = points, descrs=descs, data=data)
  90. self.targets.append(target)
  91. def clear(self):
  92. '''Remove all targets'''
  93. self.targets = []
  94. self.matcher.clear()
  95. def track(self, frame):
  96. '''Returns a list of detected TrackedTarget objects'''
  97. self.frame_points, frame_descrs = self.detect_features(frame)
  98. if len(self.frame_points) < MIN_MATCH_COUNT:
  99. return []
  100. matches = self.matcher.knnMatch(frame_descrs, k = 2)
  101. matches = [m[0] for m in matches if len(m) == 2 and m[0].distance < m[1].distance * 0.75]
  102. if len(matches) < MIN_MATCH_COUNT:
  103. return []
  104. matches_by_id = [[] for _ in xrange(len(self.targets))]
  105. for m in matches:
  106. matches_by_id[m.imgIdx].append(m)
  107. tracked = []
  108. for imgIdx, matches in enumerate(matches_by_id):
  109. if len(matches) < MIN_MATCH_COUNT:
  110. continue
  111. target = self.targets[imgIdx]
  112. p0 = [target.keypoints[m.trainIdx].pt for m in matches]
  113. p1 = [self.frame_points[m.queryIdx].pt for m in matches]
  114. p0, p1 = np.float32((p0, p1))
  115. H, status = cv.findHomography(p0, p1, cv.RANSAC, 3.0)
  116. status = status.ravel() != 0
  117. if status.sum() < MIN_MATCH_COUNT:
  118. continue
  119. p0, p1 = p0[status], p1[status]
  120. x0, y0, x1, y1 = target.rect
  121. quad = np.float32([[x0, y0], [x1, y0], [x1, y1], [x0, y1]])
  122. quad = cv.perspectiveTransform(quad.reshape(1, -1, 2), H).reshape(-1, 2)
  123. track = TrackedTarget(target=target, p0=p0, p1=p1, H=H, quad=quad)
  124. tracked.append(track)
  125. tracked.sort(key = lambda t: len(t.p0), reverse=True)
  126. return tracked
  127. def detect_features(self, frame):
  128. '''detect_features(self, frame) -> keypoints, descrs'''
  129. keypoints, descrs = self.detector.detectAndCompute(frame, None)
  130. if descrs is None: # detectAndCompute returns descs=None if no keypoints found
  131. descrs = []
  132. return keypoints, descrs
  133. if __name__ == '__main__':
  134. NewOpenCVTests.bootstrap()