plane_tracker.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. #!/usr/bin/env python
  2. '''
  3. Multitarget planar tracking
  4. ==================
  5. Example of using features2d framework for interactive video homography matching.
  6. ORB features and FLANN matcher are used. This sample provides PlaneTracker class
  7. and an example of its usage.
  8. video: http://www.youtube.com/watch?v=pzVbhxx6aog
  9. Usage
  10. -----
  11. plane_tracker.py [<video source>]
  12. Keys:
  13. SPACE - pause video
  14. c - clear targets
  15. Select a textured planar object to track by drawing a box with a mouse.
  16. '''
  17. # Python 2/3 compatibility
  18. from __future__ import print_function
  19. import sys
  20. PY3 = sys.version_info[0] == 3
  21. if PY3:
  22. xrange = range
  23. import numpy as np
  24. import cv2 as cv
  25. # built-in modules
  26. from collections import namedtuple
  27. # local modules
  28. import video
  29. import common
  30. from video import presets
  31. FLANN_INDEX_KDTREE = 1
  32. FLANN_INDEX_LSH = 6
  33. flann_params= dict(algorithm = FLANN_INDEX_LSH,
  34. table_number = 6, # 12
  35. key_size = 12, # 20
  36. multi_probe_level = 1) #2
  37. MIN_MATCH_COUNT = 10
  38. '''
  39. image - image to track
  40. rect - tracked rectangle (x1, y1, x2, y2)
  41. keypoints - keypoints detected inside rect
  42. descrs - their descriptors
  43. data - some user-provided data
  44. '''
  45. PlanarTarget = namedtuple('PlaneTarget', 'image, rect, keypoints, descrs, data')
  46. '''
  47. target - reference to PlanarTarget
  48. p0 - matched points coords in target image
  49. p1 - matched points coords in input frame
  50. H - homography matrix from p0 to p1
  51. quad - target boundary quad in input frame
  52. '''
  53. TrackedTarget = namedtuple('TrackedTarget', 'target, p0, p1, H, quad')
  54. class PlaneTracker:
  55. def __init__(self):
  56. self.detector = cv.ORB_create( nfeatures = 1000 )
  57. self.matcher = cv.FlannBasedMatcher(flann_params, {}) # bug : need to pass empty dict (#1329)
  58. self.targets = []
  59. self.frame_points = []
  60. def add_target(self, image, rect, data=None):
  61. '''Add a new tracking target.'''
  62. x0, y0, x1, y1 = rect
  63. raw_points, raw_descrs = self.detect_features(image)
  64. points, descs = [], []
  65. for kp, desc in zip(raw_points, raw_descrs):
  66. x, y = kp.pt
  67. if x0 <= x <= x1 and y0 <= y <= y1:
  68. points.append(kp)
  69. descs.append(desc)
  70. descs = np.uint8(descs)
  71. self.matcher.add([descs])
  72. target = PlanarTarget(image = image, rect=rect, keypoints = points, descrs=descs, data=data)
  73. self.targets.append(target)
  74. def clear(self):
  75. '''Remove all targets'''
  76. self.targets = []
  77. self.matcher.clear()
  78. def track(self, frame):
  79. '''Returns a list of detected TrackedTarget objects'''
  80. self.frame_points, frame_descrs = self.detect_features(frame)
  81. if len(self.frame_points) < MIN_MATCH_COUNT:
  82. return []
  83. matches = self.matcher.knnMatch(frame_descrs, k = 2)
  84. matches = [m[0] for m in matches if len(m) == 2 and m[0].distance < m[1].distance * 0.75]
  85. if len(matches) < MIN_MATCH_COUNT:
  86. return []
  87. matches_by_id = [[] for _ in xrange(len(self.targets))]
  88. for m in matches:
  89. matches_by_id[m.imgIdx].append(m)
  90. tracked = []
  91. for imgIdx, matches in enumerate(matches_by_id):
  92. if len(matches) < MIN_MATCH_COUNT:
  93. continue
  94. target = self.targets[imgIdx]
  95. p0 = [target.keypoints[m.trainIdx].pt for m in matches]
  96. p1 = [self.frame_points[m.queryIdx].pt for m in matches]
  97. p0, p1 = np.float32((p0, p1))
  98. H, status = cv.findHomography(p0, p1, cv.RANSAC, 3.0)
  99. status = status.ravel() != 0
  100. if status.sum() < MIN_MATCH_COUNT:
  101. continue
  102. p0, p1 = p0[status], p1[status]
  103. x0, y0, x1, y1 = target.rect
  104. quad = np.float32([[x0, y0], [x1, y0], [x1, y1], [x0, y1]])
  105. quad = cv.perspectiveTransform(quad.reshape(1, -1, 2), H).reshape(-1, 2)
  106. track = TrackedTarget(target=target, p0=p0, p1=p1, H=H, quad=quad)
  107. tracked.append(track)
  108. tracked.sort(key = lambda t: len(t.p0), reverse=True)
  109. return tracked
  110. def detect_features(self, frame):
  111. '''detect_features(self, frame) -> keypoints, descrs'''
  112. keypoints, descrs = self.detector.detectAndCompute(frame, None)
  113. if descrs is None: # detectAndCompute returns descs=None if not keypoints found
  114. descrs = []
  115. return keypoints, descrs
  116. class App:
  117. def __init__(self, src):
  118. self.cap = video.create_capture(src, presets['book'])
  119. self.frame = None
  120. self.paused = False
  121. self.tracker = PlaneTracker()
  122. cv.namedWindow('plane')
  123. self.rect_sel = common.RectSelector('plane', self.on_rect)
  124. def on_rect(self, rect):
  125. self.tracker.add_target(self.frame, rect)
  126. def run(self):
  127. while True:
  128. playing = not self.paused and not self.rect_sel.dragging
  129. if playing or self.frame is None:
  130. ret, frame = self.cap.read()
  131. if not ret:
  132. break
  133. self.frame = frame.copy()
  134. vis = self.frame.copy()
  135. if playing:
  136. tracked = self.tracker.track(self.frame)
  137. for tr in tracked:
  138. cv.polylines(vis, [np.int32(tr.quad)], True, (255, 255, 255), 2)
  139. for (x, y) in np.int32(tr.p1):
  140. cv.circle(vis, (x, y), 2, (255, 255, 255))
  141. self.rect_sel.draw(vis)
  142. cv.imshow('plane', vis)
  143. ch = cv.waitKey(1)
  144. if ch == ord(' '):
  145. self.paused = not self.paused
  146. if ch == ord('c'):
  147. self.tracker.clear()
  148. if ch == 27:
  149. break
  150. if __name__ == '__main__':
  151. print(__doc__)
  152. import sys
  153. try:
  154. video_src = sys.argv[1]
  155. except:
  156. video_src = 0
  157. App(video_src).run()