csrt.cpp 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  #include <algorithm>
  #include <cstring>
  #include <fstream>
  #include <iostream>
  #include <sstream>
  #include <string>
  #include <vector>
  #include <opencv2/core/utility.hpp>
  #include <opencv2/highgui.hpp>
  #include <opencv2/imgproc.hpp>
  #include <opencv2/tracking.hpp>
  #include <opencv2/videoio.hpp>
  #include "samples_utility.hpp"
  using namespace std;
  using namespace cv;
  11. int main(int argc, char** argv)
  12. {
  13. // show help
  14. if (argc<2) {
  15. cout <<
  16. " Usage: example_tracking_csrt <video_name>\n"
  17. " examples:\n"
  18. " example_tracking_csrt Bolt/img/%04.jpg\n"
  19. " example_tracking_csrt Bolt/img/%04.jpg Bolt/grouondtruth.txt\n"
  20. " example_tracking_csrt faceocc2.webm\n"
  21. << endl;
  22. return 0;
  23. }
  24. // create the tracker
  25. Ptr<TrackerCSRT> tracker = TrackerCSRT::create();
  26. // const char* param_file_path = "/home/amuhic/Workspace/3_dip/params.yml";
  27. // FileStorage fs(params_file_path, FileStorage::WRITE);
  28. // tracker->write(fs);
  29. // FileStorage fs(param_file_path, FileStorage::READ);
  30. // tracker->read( fs.root());
  31. // set input video
  32. std::string video = argv[1];
  33. VideoCapture cap(video);
  34. // and read first frame
  35. Mat frame;
  36. cap >> frame;
  37. // target bounding box
  38. Rect roi;
  39. if (argc > 2) {
  40. // read first line of ground-truth file
  41. std::string groundtruthPath = argv[2];
  42. std::ifstream gtIfstream(groundtruthPath.c_str());
  43. std::string gtLine;
  44. getline(gtIfstream, gtLine);
  45. gtIfstream.close();
  46. // parse the line by elements
  47. std::stringstream gtStream(gtLine);
  48. std::string element;
  49. std::vector<int> elements;
  50. while (std::getline(gtStream, element, ','))
  51. {
  52. elements.push_back(cvRound(std::atof(element.c_str())));
  53. }
  54. if (elements.size() == 4) {
  55. // ground-truth is rectangle
  56. roi = cv::Rect(elements[0], elements[1], elements[2], elements[3]);
  57. }
  58. else if (elements.size() == 8) {
  59. // ground-truth is polygon
  60. int xMin = cvRound(min(elements[0], min(elements[2], min(elements[4], elements[6]))));
  61. int yMin = cvRound(min(elements[1], min(elements[3], min(elements[5], elements[7]))));
  62. int xMax = cvRound(max(elements[0], max(elements[2], max(elements[4], elements[6]))));
  63. int yMax = cvRound(max(elements[1], max(elements[3], max(elements[5], elements[7]))));
  64. roi = cv::Rect(xMin, yMin, xMax - xMin, yMax - yMin);
  65. // create mask from polygon and set it to the tracker
  66. cv::Rect aaRect = cv::Rect(xMin, yMin, xMax - xMin, yMax - yMin);
  67. cout << aaRect.size() << endl;
  68. Mat mask = Mat::zeros(aaRect.size(), CV_8UC1);
  69. const int n = 4;
  70. std::vector<cv::Point> poly_points(n);
  71. //Translate x and y to rects start position
  72. int sx = aaRect.x;
  73. int sy = aaRect.y;
  74. for (int i = 0; i < n; ++i) {
  75. poly_points[i] = Point(elements[2 * i] - sx, elements[2 * i + 1] - sy);
  76. }
  77. cv::fillConvexPoly(mask, poly_points, Scalar(1.0), 8);
  78. mask.convertTo(mask, CV_32FC1);
  79. tracker->setInitialMask(mask);
  80. }
  81. else {
  82. std::cout << "Number of ground-truth elements is not 4 or 8." << std::endl;
  83. }
  84. }
  85. else {
  86. // second argument is not given - user selects target
  87. roi = selectROI("tracker", frame, true, false);
  88. }
  89. //quit if ROI was not selected
  90. if (roi.width == 0 || roi.height == 0)
  91. return 0;
  92. // initialize the tracker
  93. int64 t1 = cv::getTickCount();
  94. tracker->init(frame, roi);
  95. int64 t2 = cv::getTickCount();
  96. int64 tick_counter = t2 - t1;
  97. // do the tracking
  98. printf("Start the tracking process, press ESC to quit.\n");
  99. int frame_idx = 1;
  100. for (;;) {
  101. // get frame from the video
  102. cap >> frame;
  103. // stop the program if no more images
  104. if (frame.rows == 0 || frame.cols == 0)
  105. break;
  106. // update the tracking result
  107. t1 = cv::getTickCount();
  108. bool isfound = tracker->update(frame, roi);
  109. t2 = cv::getTickCount();
  110. tick_counter += t2 - t1;
  111. frame_idx++;
  112. if (!isfound) {
  113. cout << "The target has been lost...\n";
  114. waitKey(0);
  115. return 0;
  116. }
  117. // draw the tracked object and show the image
  118. rectangle(frame, roi, Scalar(255, 0, 0), 2, 1);
  119. imshow("tracker", frame);
  120. //quit on ESC button
  121. if (waitKey(1) == 27)break;
  122. }
  123. cout << "Elapsed sec: " << static_cast<double>(tick_counter) / cv::getTickFrequency() << endl;
  124. cout << "FPS: " << ((double)(frame_idx)) / (static_cast<double>(tick_counter) / cv::getTickFrequency()) << endl;
  125. }