// dasiamrpn_tracker.cpp
  1. // DaSiamRPN tracker.
  2. // Original paper: https://arxiv.org/abs/1808.06048
  3. // Link to original repo: https://github.com/foolwood/DaSiamRPN
  4. // Links to onnx models:
  5. // - network: https://www.dropbox.com/s/rr1lk9355vzolqv/dasiamrpn_model.onnx?dl=0
  6. // - kernel_r1: https://www.dropbox.com/s/999cqx5zrfi7w4p/dasiamrpn_kernel_r1.onnx?dl=0
  7. // - kernel_cls1: https://www.dropbox.com/s/qvmtszx5h339a0w/dasiamrpn_kernel_cls1.onnx?dl=0
#include <cctype>
#include <cmath>
#include <iostream>
#include <string>

#include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/video.hpp>
  14. using namespace cv;
  15. using namespace cv::dnn;
  16. const char *keys =
  17. "{ help h | | Print help message }"
  18. "{ input i | | Full path to input video folder, the specific camera index. (empty for camera 0) }"
  19. "{ net | dasiamrpn_model.onnx | Path to onnx model of net}"
  20. "{ kernel_cls1 | dasiamrpn_kernel_cls1.onnx | Path to onnx model of kernel_r1 }"
  21. "{ kernel_r1 | dasiamrpn_kernel_r1.onnx | Path to onnx model of kernel_cls1 }"
  22. "{ backend | 0 | Choose one of computation backends: "
  23. "0: automatically (by default), "
  24. "1: Halide language (http://halide-lang.org/), "
  25. "2: Intel's Deep Learning Inference Engine (https://software.intel.com/openvino-toolkit), "
  26. "3: OpenCV implementation, "
  27. "4: VKCOM, "
  28. "5: CUDA },"
  29. "{ target | 0 | Choose one of target computation devices: "
  30. "0: CPU target (by default), "
  31. "1: OpenCL, "
  32. "2: OpenCL fp16 (half-float precision), "
  33. "3: VPU, "
  34. "4: Vulkan, "
  35. "6: CUDA, "
  36. "7: CUDA fp16 (half-float preprocess) }"
  37. ;
  38. static
  39. int run(int argc, char** argv)
  40. {
  41. // Parse command line arguments.
  42. CommandLineParser parser(argc, argv, keys);
  43. if (parser.has("help"))
  44. {
  45. parser.printMessage();
  46. return 0;
  47. }
  48. std::string inputName = parser.get<String>("input");
  49. std::string net = parser.get<String>("net");
  50. std::string kernel_cls1 = parser.get<String>("kernel_cls1");
  51. std::string kernel_r1 = parser.get<String>("kernel_r1");
  52. int backend = parser.get<int>("backend");
  53. int target = parser.get<int>("target");
  54. Ptr<TrackerDaSiamRPN> tracker;
  55. try
  56. {
  57. TrackerDaSiamRPN::Params params;
  58. params.model = samples::findFile(net);
  59. params.kernel_cls1 = samples::findFile(kernel_cls1);
  60. params.kernel_r1 = samples::findFile(kernel_r1);
  61. params.backend = backend;
  62. params.target = target;
  63. tracker = TrackerDaSiamRPN::create(params);
  64. }
  65. catch (const cv::Exception& ee)
  66. {
  67. std::cerr << "Exception: " << ee.what() << std::endl;
  68. std::cout << "Can't load the network by using the following files:" << std::endl;
  69. std::cout << "siamRPN : " << net << std::endl;
  70. std::cout << "siamKernelCL1 : " << kernel_cls1 << std::endl;
  71. std::cout << "siamKernelR1 : " << kernel_r1 << std::endl;
  72. return 2;
  73. }
  74. const std::string winName = "DaSiamRPN";
  75. namedWindow(winName, WINDOW_AUTOSIZE);
  76. // Open a video file or an image file or a camera stream.
  77. VideoCapture cap;
  78. if (inputName.empty() || (isdigit(inputName[0]) && inputName.size() == 1))
  79. {
  80. int c = inputName.empty() ? 0 : inputName[0] - '0';
  81. std::cout << "Trying to open camera #" << c << " ..." << std::endl;
  82. if (!cap.open(c))
  83. {
  84. std::cout << "Capture from camera #" << c << " didn't work. Specify -i=<video> parameter to read from video file" << std::endl;
  85. return 2;
  86. }
  87. }
  88. else if (inputName.size())
  89. {
  90. inputName = samples::findFileOrKeep(inputName);
  91. if (!cap.open(inputName))
  92. {
  93. std::cout << "Could not open: " << inputName << std::endl;
  94. return 2;
  95. }
  96. }
  97. // Read the first image.
  98. Mat image;
  99. cap >> image;
  100. if (image.empty())
  101. {
  102. std::cerr << "Can't capture frame!" << std::endl;
  103. return 2;
  104. }
  105. Mat image_select = image.clone();
  106. putText(image_select, "Select initial bounding box you want to track.", Point(0, 15), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 255, 0));
  107. putText(image_select, "And Press the ENTER key.", Point(0, 35), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 255, 0));
  108. Rect selectRect = selectROI(winName, image_select);
  109. std::cout << "ROI=" << selectRect << std::endl;
  110. tracker->init(image, selectRect);
  111. TickMeter tickMeter;
  112. for (int count = 0; ; ++count)
  113. {
  114. cap >> image;
  115. if (image.empty())
  116. {
  117. std::cerr << "Can't capture frame " << count << ". End of video stream?" << std::endl;
  118. break;
  119. }
  120. Rect rect;
  121. tickMeter.start();
  122. bool ok = tracker->update(image, rect);
  123. tickMeter.stop();
  124. float score = tracker->getTrackingScore();
  125. std::cout << "frame " << count <<
  126. ": predicted score=" << score <<
  127. " rect=" << rect <<
  128. " time=" << tickMeter.getTimeMilli() << "ms" <<
  129. std::endl;
  130. Mat render_image = image.clone();
  131. if (ok)
  132. {
  133. rectangle(render_image, rect, Scalar(0, 255, 0), 2);
  134. std::string timeLabel = format("Inference time: %.2f ms", tickMeter.getTimeMilli());
  135. std::string scoreLabel = format("Score: %f", score);
  136. putText(render_image, timeLabel, Point(0, 15), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 255, 0));
  137. putText(render_image, scoreLabel, Point(0, 35), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 255, 0));
  138. }
  139. imshow(winName, render_image);
  140. tickMeter.reset();
  141. int c = waitKey(1);
  142. if (c == 27 /*ESC*/)
  143. break;
  144. }
  145. std::cout << "Exit" << std::endl;
  146. return 0;
  147. }
  148. int main(int argc, char **argv)
  149. {
  150. try
  151. {
  152. return run(argc, argv);
  153. }
  154. catch (const std::exception& e)
  155. {
  156. std::cerr << "FATAL: C++ exception: " << e.what() << std::endl;
  157. return 1;
  158. }
  159. }