// scene_text_detection.cpp
  1. #include <iostream>
  2. #include <fstream>
  3. #include <opencv2/imgproc.hpp>
  4. #include <opencv2/highgui.hpp>
  5. #include <opencv2/dnn/dnn.hpp>
  6. using namespace cv;
  7. using namespace cv::dnn;
  8. std::string keys =
  9. "{ help h | | Print help message. }"
  10. "{ inputImage i | | Path to an input image. Skip this argument to capture frames from a camera. }"
  11. "{ modelPath mp | | Path to a binary .onnx file contains trained DB detector model. "
  12. "Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}"
  13. "{ inputHeight ih |736| image height of the model input. It should be multiple by 32.}"
  14. "{ inputWidth iw |736| image width of the model input. It should be multiple by 32.}"
  15. "{ binaryThreshold bt |0.3| Confidence threshold of the binary map. }"
  16. "{ polygonThreshold pt |0.5| Confidence threshold of polygons. }"
  17. "{ maxCandidate max |200| Max candidates of polygons. }"
  18. "{ unclipRatio ratio |2.0| unclip ratio. }"
  19. "{ evaluate e |false| false: predict with input images; true: evaluate on benchmarks. }"
  20. "{ evalDataPath edp | | Path to benchmarks for evaluation. "
  21. "Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}";
  22. static
  23. void split(const std::string& s, char delimiter, std::vector<std::string>& elems)
  24. {
  25. elems.clear();
  26. size_t prev_pos = 0;
  27. size_t pos = 0;
  28. while ((pos = s.find(delimiter, prev_pos)) != std::string::npos)
  29. {
  30. elems.emplace_back(s.substr(prev_pos, pos - prev_pos));
  31. prev_pos = pos + 1;
  32. }
  33. if (prev_pos < s.size())
  34. elems.emplace_back(s.substr(prev_pos, s.size() - prev_pos));
  35. }
  36. int main(int argc, char** argv)
  37. {
  38. // Parse arguments
  39. CommandLineParser parser(argc, argv, keys);
  40. parser.about("Use this script to run the official PyTorch implementation (https://github.com/MhLiao/DB) of "
  41. "Real-time Scene Text Detection with Differentiable Binarization (https://arxiv.org/abs/1911.08947)\n"
  42. "The current version of this script is a variant of the original network without deformable convolution");
  43. if (argc == 1 || parser.has("help"))
  44. {
  45. parser.printMessage();
  46. return 0;
  47. }
  48. float binThresh = parser.get<float>("binaryThreshold");
  49. float polyThresh = parser.get<float>("polygonThreshold");
  50. uint maxCandidates = parser.get<uint>("maxCandidate");
  51. String modelPath = parser.get<String>("modelPath");
  52. double unclipRatio = parser.get<double>("unclipRatio");
  53. int height = parser.get<int>("inputHeight");
  54. int width = parser.get<int>("inputWidth");
  55. if (!parser.check())
  56. {
  57. parser.printErrors();
  58. return 1;
  59. }
  60. // Load the network
  61. CV_Assert(!modelPath.empty());
  62. TextDetectionModel_DB detector(modelPath);
  63. detector.setBinaryThreshold(binThresh)
  64. .setPolygonThreshold(polyThresh)
  65. .setUnclipRatio(unclipRatio)
  66. .setMaxCandidates(maxCandidates);
  67. double scale = 1.0 / 255.0;
  68. Size inputSize = Size(width, height);
  69. Scalar mean = Scalar(122.67891434, 116.66876762, 104.00698793);
  70. detector.setInputParams(scale, inputSize, mean);
  71. // Create a window
  72. static const std::string winName = "TextDetectionModel";
  73. if (parser.get<bool>("evaluate")) {
  74. // for evaluation
  75. String evalDataPath = parser.get<String>("evalDataPath");
  76. CV_Assert(!evalDataPath.empty());
  77. String testListPath = evalDataPath + "/test_list.txt";
  78. std::ifstream testList;
  79. testList.open(testListPath);
  80. CV_Assert(testList.is_open());
  81. // Create a window for showing groundtruth
  82. static const std::string winNameGT = "GT";
  83. String testImgPath;
  84. while (std::getline(testList, testImgPath)) {
  85. String imgPath = evalDataPath + "/test_images/" + testImgPath;
  86. std::cout << "Image Path: " << imgPath << std::endl;
  87. Mat frame = imread(samples::findFile(imgPath), IMREAD_COLOR);
  88. CV_Assert(!frame.empty());
  89. Mat src = frame.clone();
  90. // Inference
  91. std::vector<std::vector<Point>> results;
  92. detector.detect(frame, results);
  93. polylines(frame, results, true, Scalar(0, 255, 0), 2);
  94. imshow(winName, frame);
  95. // load groundtruth
  96. String imgName = testImgPath.substr(0, testImgPath.length() - 4);
  97. String gtPath = evalDataPath + "/test_gts/" + imgName + ".txt";
  98. // std::cout << gtPath << std::endl;
  99. std::ifstream gtFile;
  100. gtFile.open(gtPath);
  101. CV_Assert(gtFile.is_open());
  102. std::vector<std::vector<Point>> gts;
  103. String gtLine;
  104. while (std::getline(gtFile, gtLine)) {
  105. size_t splitLoc = gtLine.find_last_of(',');
  106. String text = gtLine.substr(splitLoc+1);
  107. if ( text == "###\r" || text == "1") {
  108. // ignore difficult instances
  109. continue;
  110. }
  111. gtLine = gtLine.substr(0, splitLoc);
  112. std::vector<std::string> v;
  113. split(gtLine, ',', v);
  114. std::vector<int> loc;
  115. std::vector<Point> pts;
  116. for (auto && s : v) {
  117. loc.push_back(atoi(s.c_str()));
  118. }
  119. for (size_t i = 0; i < loc.size() / 2; i++) {
  120. pts.push_back(Point(loc[2 * i], loc[2 * i + 1]));
  121. }
  122. gts.push_back(pts);
  123. }
  124. polylines(src, gts, true, Scalar(0, 255, 0), 2);
  125. imshow(winNameGT, src);
  126. waitKey();
  127. }
  128. } else {
  129. // Open an image file
  130. CV_Assert(parser.has("inputImage"));
  131. Mat frame = imread(samples::findFile(parser.get<String>("inputImage")));
  132. CV_Assert(!frame.empty());
  133. // Detect
  134. std::vector<std::vector<Point>> results;
  135. detector.detect(frame, results);
  136. polylines(frame, results, true, Scalar(0, 255, 0), 2);
  137. imshow(winName, frame);
  138. waitKey();
  139. }
  140. return 0;
  141. }