#!/usr/bin/env python

'''
MOSSE tracking sample

This sample implements a correlation-based tracking approach, described in [1].

Usage:
  mosse.py [--pause] [<video source>]

  --pause  -  Start with playback paused at the first video frame.
              Useful for tracking target selection.

  Draw rectangles around objects with a mouse to track them.

Keys:
  SPACE    - pause video
  c        - clear targets

[1] David S. Bolme et al. "Visual Object Tracking using Adaptive Correlation Filters"
    http://www.cs.colostate.edu/~draper/papers/bolme_cvpr10.pdf
'''

# Python 2/3 compatibility
from __future__ import print_function
import sys
PY3 = sys.version_info[0] == 3

if PY3:
    xrange = range

import numpy as np
import cv2 as cv

from common import draw_str, RectSelector
import video
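
# rnd_warp applies a small random affine perturbation (a rotation plus a random
# shear/scale component, both bounded by `coef`) about the patch centre.
# MOSSE.__init__ below uses it to synthesise training samples from the single
# user-selected patch.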
def rnd_warp(a):
    h, w = a.shape[:2]
    T = np.zeros((2, 3))
    coef = 0.2
    ang = (np.random.rand()-0.5)*coef
    c, s = np.cos(ang), np.sin(ang)
    T[:2, :2] = [[c,-s], [s, c]]
    T[:2, :2] += (np.random.rand(2, 2) - 0.5)*coef
    c = (w/2, h/2)
    T[:,2] = c - np.dot(T[:2, :2], c)
    return cv.warpAffine(a, T, (w, h), borderMode = cv.BORDER_REFLECT)
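
# divSpec performs element-wise complex division C = A / B of two spectra
# stored in OpenCV's two-channel (real, imaginary) DFT format. It is used to
# form the filter H = H1 / H2 from the accumulated numerator and denominator.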
def divSpec(A, B):
    Ar, Ai = A[...,0], A[...,1]
    Br, Bi = B[...,0], B[...,1]
    C = (Ar+1j*Ai)/(Br+1j*Bi)
    C = np.dstack([np.real(C), np.imag(C)]).copy()
    return C

eps = 1e-5
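
# MOSSE filter (see [1]): given training patches F_i and a desired Gaussian
# response G, the filter is
#     H* = sum_i(G o conj(F_i)) / sum_i(F_i o conj(F_i))
# where 'o' denotes element-wise multiplication in the Fourier domain.
# H1 and H2 below accumulate the numerator and denominator; update() then
# keeps them as running averages so the filter adapts to appearance changes.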
class MOSSE:
    def __init__(self, frame, rect):
        x1, y1, x2, y2 = rect
        w, h = map(cv.getOptimalDFTSize, [x2-x1, y2-y1])
        x1, y1 = (x1+x2-w)//2, (y1+y2-h)//2
        self.pos = x, y = x1+0.5*(w-1), y1+0.5*(h-1)
        self.size = w, h
        img = cv.getRectSubPix(frame, (w, h), (x, y))

        self.win = cv.createHanningWindow((w, h), cv.CV_32F)
        g = np.zeros((h, w), np.float32)
        g[h//2, w//2] = 1
        g = cv.GaussianBlur(g, (-1, -1), 2.0)
        g /= g.max()

        self.G = cv.dft(g, flags=cv.DFT_COMPLEX_OUTPUT)
        self.H1 = np.zeros_like(self.G)
        self.H2 = np.zeros_like(self.G)
        for _i in xrange(128):
            a = self.preprocess(rnd_warp(img))
            A = cv.dft(a, flags=cv.DFT_COMPLEX_OUTPUT)
            self.H1 += cv.mulSpectrums(self.G, A, 0, conjB=True)
            self.H2 += cv.mulSpectrums(     A, A, 0, conjB=True)
        self.update_kernel()
        self.update(frame)
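
    # Per-frame update: correlate the preprocessed patch at the current
    # position with the filter, move to the response peak, and blend new
    # H1/H2 terms into the running averages with learning rate `rate`.
    # Tracking is flagged as lost when the PSR drops to 8.0 or below.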
    def update(self, frame, rate = 0.125):
        (x, y), (w, h) = self.pos, self.size
        self.last_img = img = cv.getRectSubPix(frame, (w, h), (x, y))
        img = self.preprocess(img)
        self.last_resp, (dx, dy), self.psr = self.correlate(img)
        self.good = self.psr > 8.0
        if not self.good:
            return

        self.pos = x+dx, y+dy
        self.last_img = img = cv.getRectSubPix(frame, (w, h), self.pos)
        img = self.preprocess(img)

        A = cv.dft(img, flags=cv.DFT_COMPLEX_OUTPUT)
        H1 = cv.mulSpectrums(self.G, A, 0, conjB=True)
        H2 = cv.mulSpectrums(     A, A, 0, conjB=True)
        self.H1 = self.H1 * (1.0-rate) + H1 * rate
        self.H2 = self.H2 * (1.0-rate) + H2 * rate
        self.update_kernel()
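
    # Debug visualisation: the last input patch, the spatial-domain filter
    # (inverse DFT of H, rolled so its origin sits in the middle) and the last
    # correlation response, stacked side by side.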
    @property
    def state_vis(self):
        f = cv.idft(self.H, flags=cv.DFT_SCALE | cv.DFT_REAL_OUTPUT )
        h, w = f.shape
        f = np.roll(f, -h//2, 0)
        f = np.roll(f, -w//2, 1)
        kernel = np.uint8( (f-f.min()) / f.ptp()*255 )
        resp = self.last_resp
        resp = np.uint8(np.clip(resp/resp.max(), 0, 1)*255)
        vis = np.hstack([self.last_img, kernel, resp])
        return vis

    def draw_state(self, vis):
        (x, y), (w, h) = self.pos, self.size
        x1, y1, x2, y2 = int(x-0.5*w), int(y-0.5*h), int(x+0.5*w), int(y+0.5*h)
        cv.rectangle(vis, (x1, y1), (x2, y2), (0, 0, 255))
        if self.good:
            cv.circle(vis, (int(x), int(y)), 2, (0, 0, 255), -1)
        else:
            cv.line(vis, (x1, y1), (x2, y2), (0, 0, 255))
            cv.line(vis, (x2, y1), (x1, y2), (0, 0, 255))
        draw_str(vis, (x1, y2+16), 'PSR: %.2f' % self.psr)
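
    # Preprocessing from [1]: log transform to reduce lighting effects,
    # normalisation to zero mean / unit variance, and a Hanning window to
    # suppress the edge effects introduced by circular correlation.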
    def preprocess(self, img):
        img = np.log(np.float32(img)+1.0)
        img = (img-img.mean()) / (img.std()+eps)
        return img*self.win
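
    # Correlation is done in the Fourier domain: multiply the patch spectrum
    # by the conjugated kernel and take the inverse DFT to get the response
    # map. The peak-to-sidelobe ratio (PSR) measures peak strength against the
    # surrounding response, with an 11x11 window around the peak zeroed out
    # before computing the sidelobe mean and standard deviation.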
    def correlate(self, img):
        C = cv.mulSpectrums(cv.dft(img, flags=cv.DFT_COMPLEX_OUTPUT), self.H, 0, conjB=True)
        resp = cv.idft(C, flags=cv.DFT_SCALE | cv.DFT_REAL_OUTPUT)
        h, w = resp.shape
        _, mval, _, (mx, my) = cv.minMaxLoc(resp)
        side_resp = resp.copy()
        cv.rectangle(side_resp, (mx-5, my-5), (mx+5, my+5), 0, -1)
        smean, sstd = side_resp.mean(), side_resp.std()
        psr = (mval-smean) / (sstd+eps)
        return resp, (mx-w//2, my-h//2), psr
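
    # Rebuild the filter from the running sums: H = H1 / H2 (element-wise
    # complex division). Negating the imaginary channel stores the conjugate,
    # which correlate() undoes again via conjB=True in cv.mulSpectrums.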
    def update_kernel(self):
        self.H = divSpec(self.H1, self.H2)
        self.H[...,1] *= -1
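
# Simple GUI driver: grabs frames from the chosen video source, lets the user
# draw rectangles with RectSelector to spawn MOSSE trackers, and overlays each
# tracker's state on the frame. The most recently added tracker's internals
# (patch, kernel, response) are shown in a separate 'tracker state' window.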
class App:
    def __init__(self, video_src, paused = False):
        self.cap = video.create_capture(video_src)
        _, self.frame = self.cap.read()
        cv.imshow('frame', self.frame)
        self.rect_sel = RectSelector('frame', self.onrect)
        self.trackers = []
        self.paused = paused

    def onrect(self, rect):
        frame_gray = cv.cvtColor(self.frame, cv.COLOR_BGR2GRAY)
        tracker = MOSSE(frame_gray, rect)
        self.trackers.append(tracker)

    def run(self):
        while True:
            if not self.paused:
                ret, self.frame = self.cap.read()
                if not ret:
                    break
                frame_gray = cv.cvtColor(self.frame, cv.COLOR_BGR2GRAY)
                for tracker in self.trackers:
                    tracker.update(frame_gray)

            vis = self.frame.copy()
            for tracker in self.trackers:
                tracker.draw_state(vis)
            if len(self.trackers) > 0:
                cv.imshow('tracker state', self.trackers[-1].state_vis)
            self.rect_sel.draw(vis)

            cv.imshow('frame', vis)
            ch = cv.waitKey(10)
            if ch == 27:
                break
            if ch == ord(' '):
                self.paused = not self.paused
            if ch == ord('c'):
                self.trackers = []

if __name__ == '__main__':
    print(__doc__)
    import sys, getopt
    opts, args = getopt.getopt(sys.argv[1:], '', ['pause'])
    opts = dict(opts)
    try:
        video_src = args[0]
    except:
        video_src = '0'

    App(video_src, paused = '--pause' in opts).run()