tracker.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. #!/usr/bin/env python
  2. '''
  3. Tracker demo
  4. For usage download models by following links
  5. For GOTURN:
  6. goturn.prototxt and goturn.caffemodel: https://github.com/opencv/opencv_extra/tree/c4219d5eb3105ed8e634278fad312a1a8d2c182d/testdata/tracking
  7. For DaSiamRPN:
  8. network: https://www.dropbox.com/s/rr1lk9355vzolqv/dasiamrpn_model.onnx?dl=0
  9. kernel_r1: https://www.dropbox.com/s/999cqx5zrfi7w4p/dasiamrpn_kernel_r1.onnx?dl=0
  10. kernel_cls1: https://www.dropbox.com/s/qvmtszx5h339a0w/dasiamrpn_kernel_cls1.onnx?dl=0
  11. USAGE:
  12. tracker.py [-h] [--input INPUT] [--tracker_algo TRACKER_ALGO]
  13. [--goturn GOTURN] [--goturn_model GOTURN_MODEL]
  14. [--dasiamrpn_net DASIAMRPN_NET]
  15. [--dasiamrpn_kernel_r1 DASIAMRPN_KERNEL_R1]
  16. [--dasiamrpn_kernel_cls1 DASIAMRPN_KERNEL_CLS1]
  17. [--dasiamrpn_backend DASIAMRPN_BACKEND]
  18. [--dasiamrpn_target DASIAMRPN_TARGET]
  19. '''
  20. # Python 2/3 compatibility
  21. from __future__ import print_function
  22. import sys
  23. import numpy as np
  24. import cv2 as cv
  25. import argparse
  26. from video import create_capture, presets
  27. class App(object):
  28. def __init__(self, args):
  29. self.args = args
  30. self.trackerAlgorithm = args.tracker_algo
  31. self.tracker = self.createTracker()
  32. def createTracker(self):
  33. if self.trackerAlgorithm == 'mil':
  34. tracker = cv.TrackerMIL_create()
  35. elif self.trackerAlgorithm == 'goturn':
  36. params = cv.TrackerGOTURN_Params()
  37. params.modelTxt = self.args.goturn
  38. params.modelBin = self.args.goturn_model
  39. tracker = cv.TrackerGOTURN_create(params)
  40. elif self.trackerAlgorithm == 'dasiamrpn':
  41. params = cv.TrackerDaSiamRPN_Params()
  42. params.model = self.args.dasiamrpn_net
  43. params.kernel_cls1 = self.args.dasiamrpn_kernel_cls1
  44. params.kernel_r1 = self.args.dasiamrpn_kernel_r1
  45. tracker = cv.TrackerDaSiamRPN_create(params)
  46. else:
  47. sys.exit("Tracker {} is not recognized. Please use one of three available: mil, goturn, dasiamrpn.".format(self.trackerAlgorithm))
  48. return tracker
  49. def initializeTracker(self, image):
  50. while True:
  51. print('==> Select object ROI for tracker ...')
  52. bbox = cv.selectROI('tracking', image)
  53. print('ROI: {}'.format(bbox))
  54. if bbox[2] <= 0 or bbox[3] <= 0:
  55. sys.exit("ROI selection cancelled. Exiting...")
  56. try:
  57. self.tracker.init(image, bbox)
  58. except Exception as e:
  59. print('Unable to initialize tracker with requested bounding box. Is there any object?')
  60. print(e)
  61. print('Try again ...')
  62. continue
  63. return
  64. def run(self):
  65. videoPath = self.args.input
  66. print('Using video: {}'.format(videoPath))
  67. camera = create_capture(cv.samples.findFileOrKeep(videoPath), presets['cube'])
  68. if not camera.isOpened():
  69. sys.exit("Can't open video stream: {}".format(videoPath))
  70. ok, image = camera.read()
  71. if not ok:
  72. sys.exit("Can't read first frame")
  73. assert image is not None
  74. cv.namedWindow('tracking')
  75. self.initializeTracker(image)
  76. print("==> Tracking is started. Press 'SPACE' to re-initialize tracker or 'ESC' for exit...")
  77. while camera.isOpened():
  78. ok, image = camera.read()
  79. if not ok:
  80. print("Can't read frame")
  81. break
  82. ok, newbox = self.tracker.update(image)
  83. #print(ok, newbox)
  84. if ok:
  85. cv.rectangle(image, newbox, (200,0,0))
  86. cv.imshow("tracking", image)
  87. k = cv.waitKey(1)
  88. if k == 32: # SPACE
  89. self.initializeTracker(image)
  90. if k == 27: # ESC
  91. break
  92. print('Done')
  93. if __name__ == '__main__':
  94. print(__doc__)
  95. parser = argparse.ArgumentParser(description="Run tracker")
  96. parser.add_argument("--input", type=str, default="vtest.avi", help="Path to video source")
  97. parser.add_argument("--tracker_algo", type=str, default="mil", help="One of available tracking algorithms: mil, goturn, dasiamrpn")
  98. parser.add_argument("--goturn", type=str, default="goturn.prototxt", help="Path to GOTURN architecture")
  99. parser.add_argument("--goturn_model", type=str, default="goturn.caffemodel", help="Path to GOTERN model")
  100. parser.add_argument("--dasiamrpn_net", type=str, default="dasiamrpn_model.onnx", help="Path to onnx model of DaSiamRPN net")
  101. parser.add_argument("--dasiamrpn_kernel_r1", type=str, default="dasiamrpn_kernel_r1.onnx", help="Path to onnx model of DaSiamRPN kernel_r1")
  102. parser.add_argument("--dasiamrpn_kernel_cls1", type=str, default="dasiamrpn_kernel_cls1.onnx", help="Path to onnx model of DaSiamRPN kernel_cls1")
  103. args = parser.parse_args()
  104. App(args).run()
  105. cv.destroyAllWindows()