benchmark.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  #include "opencv2/core/utility.hpp"
  #include "opencv2/highgui.hpp"
  #include "opencv2/tracking.hpp"
  #include "opencv2/videoio.hpp"
  #include "opencv2/plot.hpp"
  #include "samples_utility.hpp"
  #include <algorithm>
  #include <fstream>
  #include <iomanip>
  #include <iostream>
  #include <limits>
  #include <sstream>
  #include <string>
  #include <vector>
  using namespace std;
  using namespace cv;
  12. // TODO: do normalization ala Kalal's assessment protocol for TLD
  13. static const Scalar gtColor = Scalar(0, 255, 0);
  14. static Scalar getNextColor()
  15. {
  16. const int num = 6;
  17. static Scalar colors[num] = {Scalar(160, 0, 0), Scalar(0, 0, 160), Scalar(0, 160, 160),
  18. Scalar(160, 160, 0), Scalar(160, 0, 160), Scalar(20, 50, 160)};
  19. static int id = 0;
  20. return colors[id < num ? id++ : num - 1];
  21. }
  22. inline vector<Rect2d> readGT(const string &filename, const string &omitname)
  23. {
  24. vector<Rect2d> res;
  25. {
  26. ifstream input(filename.c_str());
  27. if (!input.is_open())
  28. CV_Error(Error::StsError, "Failed to open file");
  29. while (input)
  30. {
  31. Rect2d one;
  32. input >> one.x;
  33. input.ignore(numeric_limits<std::streamsize>::max(), ',');
  34. input >> one.y;
  35. input.ignore(numeric_limits<std::streamsize>::max(), ',');
  36. input >> one.width;
  37. input.ignore(numeric_limits<std::streamsize>::max(), ',');
  38. input >> one.height;
  39. input.ignore(numeric_limits<std::streamsize>::max(), '\n');
  40. if (input.good())
  41. res.push_back(one);
  42. }
  43. }
  44. if (!omitname.empty())
  45. {
  46. ifstream input(omitname.c_str());
  47. if (!input.is_open())
  48. CV_Error(Error::StsError, "Failed to open file");
  49. while (input)
  50. {
  51. unsigned int a = 0, b = 0;
  52. input >> a >> b;
  53. input.ignore(numeric_limits<std::streamsize>::max(), '\n');
  54. if (a > 0 && b > 0 && a < res.size() && b < res.size())
  55. {
  56. if (a > b)
  57. swap(a, b);
  58. for (vector<Rect2d>::iterator i = res.begin() + a; i != res.begin() + b; ++i)
  59. {
  60. *i = Rect2d();
  61. }
  62. }
  63. }
  64. }
  65. return res;
  66. }
  67. inline bool isGoodBox(const Rect2d &box) { return box.width > 0. && box.height > 0.; }
  // Resolution of the lost-track-ratio curve: an overlap in [0, 1] is
  // quantized into bins 0..LTRC_COUNT (i.e. LTRC_COUNT + 1 bins).
  static const int LTRC_COUNT = 100;
  69. struct AlgoWrap
  70. {
  71. AlgoWrap(const string &name_)
  72. : lastState(NotFound), name(name_), color(getNextColor()),
  73. numTotal(0), numResponse(0), numPresent(0), numCorrect_0(0), numCorrect_0_5(0),
  74. timeTotal(0), auc(LTRC_COUNT + 1, 0)
  75. {
  76. tracker = createTrackerByName(name);
  77. }
  78. enum State
  79. {
  80. NotFound,
  81. Overlap_None,
  82. Overlap_0,
  83. Overlap_0_5,
  84. };
  85. Ptr<Tracker> tracker;
  86. bool lastRes;
  87. Rect lastBox;
  88. State lastState;
  89. // visual
  90. string name;
  91. Scalar color;
  92. // results
  93. int numTotal; // frames passed to tracker
  94. int numResponse; // frames where tracker had response
  95. int numPresent; // frames where ground truth result present
  96. int numCorrect_0; // frames where overlap with GT > 0
  97. int numCorrect_0_5; // frames where overlap with GT > 0.5
  98. int64 timeTotal; // ticks
  99. vector<int> auc; // number of frames for each overlap percent
  100. void eval(const Mat &frame, const Rect2d &gtBox, bool isVerbose)
  101. {
  102. // RUN
  103. lastBox = Rect();
  104. int64 frameTime = getTickCount();
  105. lastRes = tracker->update(frame, lastBox);
  106. frameTime = getTickCount() - frameTime;
  107. // RESULTS
  108. double intersectArea = (gtBox & (Rect2d)lastBox).area();
  109. double unionArea = (gtBox | (Rect2d)lastBox).area();
  110. numTotal++;
  111. numResponse += (lastRes && isGoodBox(lastBox)) ? 1 : 0;
  112. numPresent += isGoodBox(gtBox) ? 1 : 0;
  113. double overlap = unionArea > 0. ? intersectArea / unionArea : 0.;
  114. numCorrect_0 += overlap > 0. ? 1 : 0;
  115. numCorrect_0_5 += overlap > 0.5 ? 1 : 0;
  116. auc[std::min(std::max((size_t)(overlap * LTRC_COUNT), (size_t)0), (size_t)LTRC_COUNT)]++;
  117. timeTotal += frameTime;
  118. if (isVerbose)
  119. cout << name << " - " << overlap << endl;
  120. if (isGoodBox(gtBox) != isGoodBox(lastBox)) lastState = NotFound;
  121. else if (overlap > 0.5) lastState = Overlap_0_5;
  122. else if (overlap > 0.0001) lastState = Overlap_0;
  123. else lastState = Overlap_None;
  124. }
  125. void draw(Mat &image, const Point &textPoint) const
  126. {
  127. if (lastRes)
  128. rectangle(image, lastBox, color, 2, LINE_8);
  129. string suf;
  130. switch (lastState)
  131. {
  132. case AlgoWrap::NotFound: suf = " X"; break;
  133. case AlgoWrap::Overlap_None: suf = " ~"; break;
  134. case AlgoWrap::Overlap_0: suf = " +"; break;
  135. case AlgoWrap::Overlap_0_5: suf = " ++"; break;
  136. }
  137. putText(image, name + suf, textPoint, FONT_HERSHEY_PLAIN, 1, color, 1, LINE_AA);
  138. }
  139. // calculates "lost track ratio" curve - row of values growing from 0 to 1
  140. // number of elements is LTRC_COUNT + 2
  141. Mat getLTRC() const
  142. {
  143. Mat t, res;
  144. Mat(auc).convertTo(t, CV_64F); // integral does not support CV_32S input
  145. integral(t.t(), res, CV_64F); // t is a column of values
  146. return res.row(1) / (double)numTotal;
  147. }
  148. void plotLTRC(Mat &img) const
  149. {
  150. Ptr<plot::Plot2d> p_ = plot::Plot2d::create(getLTRC());
  151. p_->render(img);
  152. }
  153. double calcAUC() const
  154. {
  155. return cv::sum(getLTRC())[0] / (double)LTRC_COUNT;
  156. }
  157. void stat(ostream &out) const
  158. {
  159. out << name << endl;
  160. out << setw(20) << "Overlap > 0 " << setw(20) << (double)numCorrect_0 / numTotal * 100
  161. << "%" << setw(20) << numCorrect_0 << endl;
  162. out << setw(20) << "Overlap > 0.5" << setw(20) << (double)numCorrect_0_5 / numTotal * 100
  163. << "%" << setw(20) << numCorrect_0_5 << endl;
  164. double p = (double)numCorrect_0_5 / numResponse;
  165. double r = (double)numCorrect_0_5 / numPresent;
  166. double f = 2 * p * r / (p + r);
  167. out << setw(20) << "Precision" << setw(20) << p * 100 << "%" << endl;
  168. out << setw(20) << "Recall " << setw(20) << r * 100 << "%" << endl;
  169. out << setw(20) << "f-measure" << setw(20) << f * 100 << "%" << endl;
  170. out << setw(20) << "AUC" << setw(20) << calcAUC() << endl;
  171. double s = (timeTotal / getTickFrequency()) / numTotal;
  172. out << setw(20) << "Performance" << setw(20) << s * 1000 << " ms/frame" << setw(20) << 1 / s
  173. << " fps" << endl;
  174. }
  175. };
  176. inline ostream &operator<<(ostream &out, const AlgoWrap &w) { w.stat(out); return out; }
  177. inline vector<AlgoWrap> initAlgorithms(const string &algList)
  178. {
  179. vector<AlgoWrap> res;
  180. istringstream input(algList);
  181. for (;;)
  182. {
  183. char one[30];
  184. input.getline(one, 30, ',');
  185. if (!input)
  186. break;
  187. cout << " " << one << " - ";
  188. AlgoWrap a(one);
  189. if (a.tracker)
  190. {
  191. res.push_back(a);
  192. cout << "OK";
  193. }
  194. else
  195. {
  196. cout << "FAILED";
  197. }
  198. cout << endl;
  199. }
  200. return res;
  201. }
  // Title of the main visualization window.
  static const string window = "Tracking API";
  203. int main(int argc, char **argv)
  204. {
  205. const string keys =
  206. "{help h||show help}"
  207. "{video||video file to process}"
  208. "{gt||ground truth file (each line describes rectangle in format: '<x>,<y>,<w>,<h>')}"
  209. "{start|0|starting frame}"
  210. "{num|0|frame number (0 for all)}"
  211. "{omit||file with omit ranges (each line describes occluded frames: '<start> <end>')}"
  212. "{plot|false|plot LTR curves at the end}"
  213. "{v|false|print each frame info}"
  214. "{@algos||comma-separated algorithm names}";
  215. CommandLineParser p(argc, argv, keys);
  216. if (p.has("help"))
  217. {
  218. p.printMessage();
  219. return 0;
  220. }
  221. int startFrame = p.get<int>("start");
  222. int frameCount = p.get<int>("num");
  223. string videoFile = p.get<string>("video");
  224. string gtFile = p.get<string>("gt");
  225. string omitFile = p.get<string>("omit");
  226. string algList = p.get<string>("@algos");
  227. bool doPlot = p.get<bool>("plot");
  228. bool isVerbose = p.get<bool>("v");
  229. if (!p.check())
  230. {
  231. p.printErrors();
  232. return 0;
  233. }
  234. cout << "Reading GT from " << gtFile << " ... ";
  235. vector<Rect2d> gt = readGT(gtFile, omitFile);
  236. if (gt.empty())
  237. CV_Error(Error::StsError, "Failed to read GT file");
  238. cout << gt.size() << " boxes" << endl;
  239. cout << "Opening video " << videoFile << " ... ";
  240. VideoCapture cap;
  241. cap.open(videoFile);
  242. if (!cap.isOpened())
  243. CV_Error(Error::StsError, "Failed to open video file");
  244. cap.set(CAP_PROP_POS_FRAMES, startFrame);
  245. cout << "at frame " << startFrame << endl;
  246. // INIT
  247. vector<AlgoWrap> algos = initAlgorithms(algList);
  248. Mat frame, image;
  249. cap >> frame;
  250. for (vector<AlgoWrap>::iterator i = algos.begin(); i != algos.end(); ++i)
  251. i->tracker->init(frame, gt[0]);
  252. // DRAW
  253. {
  254. namedWindow(window, WINDOW_AUTOSIZE);
  255. frame.copyTo(image);
  256. rectangle(image, gt[0], gtColor, 2, LINE_8);
  257. imshow(window, image);
  258. }
  259. bool paused = false;
  260. int frameId = 0;
  261. cout << "Hot keys:" << endl << " q - exit" << endl << " p - pause" << endl;
  262. for (;;)
  263. {
  264. if (!paused)
  265. {
  266. cap >> frame;
  267. if (frame.empty())
  268. {
  269. cout << "Done - video end" << endl;
  270. break;
  271. }
  272. frameId++;
  273. if (isVerbose)
  274. cout << endl << "Frame " << frameId << endl;
  275. // EVAL
  276. for (vector<AlgoWrap>::iterator i = algos.begin(); i != algos.end(); ++i)
  277. i->eval(frame, gt[frameId], isVerbose);
  278. // DRAW
  279. {
  280. Point textPoint(1, 16);
  281. frame.copyTo(image);
  282. rectangle(image, gt[frameId], gtColor, 2, LINE_8);
  283. putText(image, "GROUND TRUTH", textPoint, FONT_HERSHEY_PLAIN, 1, gtColor, 1, LINE_AA);
  284. for (vector<AlgoWrap>::iterator i = algos.begin(); i != algos.end(); ++i)
  285. {
  286. textPoint.y += 14;
  287. i->draw(image, textPoint);
  288. }
  289. imshow(window, image);
  290. }
  291. }
  292. char c = (char)waitKey(1);
  293. if (c == 'q')
  294. {
  295. cout << "Done - manual exit" << endl;
  296. break;
  297. }
  298. else if (c == 'p')
  299. {
  300. paused = !paused;
  301. }
  302. if (frameCount && frameId >= frameCount)
  303. {
  304. cout << "Done - max frame count" << endl;
  305. break;
  306. }
  307. }
  308. // STAT
  309. for (vector<AlgoWrap>::iterator i = algos.begin(); i != algos.end(); ++i)
  310. cout << "==========" << endl << *i << endl;
  311. if (doPlot)
  312. {
  313. Mat img(300, 300, CV_8UC3);
  314. for (vector<AlgoWrap>::iterator i = algos.begin(); i != algos.end(); ++i)
  315. {
  316. i->plotLTRC(img);
  317. imshow("LTR curve for " + i->name, img);
  318. }
  319. waitKey(0);
  320. }
  321. return 0;
  322. }