// scene_text_spotting.cpp
  1. #include <iostream>
  2. #include <fstream>
  3. #include <opencv2/imgproc.hpp>
  4. #include <opencv2/highgui.hpp>
  5. #include <opencv2/dnn/dnn.hpp>
  6. using namespace cv;
  7. using namespace cv::dnn;
  // Command-line schema handed to cv::CommandLineParser in main().
  // Each entry has the form "{ name alias | default | help text }".
  // The help texts describe what main() actually does with each option.
  std::string keys =
      "{ help h | | Print help message. }"
      "{ inputImage i | | Path to an input image. Required. }"
      "{ detModelPath dmp | | Path to a binary .onnx model for detection. "
          "Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}"
      "{ recModelPath rmp | | Path to a binary .onnx model for recognition. "
          "Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}"
      "{ inputHeight ih |736| Image height of the detection model input. It should be a multiple of 32.}"
      "{ inputWidth iw |736| Image width of the detection model input. It should be a multiple of 32.}"
      "{ RGBInput rgb |0| 0: text crops are converted to grayscale before recognition; "
          "1: text crops are passed to the recognition model as 3-channel color images. }"
      "{ binaryThreshold bt |0.3| Confidence threshold of the binary map. }"
      "{ polygonThreshold pt |0.5| Confidence threshold of polygons. }"
      "{ maxCandidate max |200| Max candidates of polygons. }"
      "{ unclipRatio ratio |2.0| Unclip ratio. }"
      "{ vocabularyPath vp | alphabet_36.txt | Path to the vocabulary file of the recognition model, one entry per line. "
          "Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}";
  24. void fourPointsTransform(const Mat& frame, const Point2f vertices[], Mat& result);
  25. bool sortPts(const Point& p1, const Point& p2);
  26. int main(int argc, char** argv)
  27. {
  28. // Parse arguments
  29. CommandLineParser parser(argc, argv, keys);
  30. parser.about("Use this script to run an end-to-end inference sample of textDetectionModel and textRecognitionModel APIs\n"
  31. "Use -h for more information");
  32. if (argc == 1 || parser.has("help"))
  33. {
  34. parser.printMessage();
  35. return 0;
  36. }
  37. float binThresh = parser.get<float>("binaryThreshold");
  38. float polyThresh = parser.get<float>("polygonThreshold");
  39. uint maxCandidates = parser.get<uint>("maxCandidate");
  40. String detModelPath = parser.get<String>("detModelPath");
  41. String recModelPath = parser.get<String>("recModelPath");
  42. String vocPath = parser.get<String>("vocabularyPath");
  43. double unclipRatio = parser.get<double>("unclipRatio");
  44. int height = parser.get<int>("inputHeight");
  45. int width = parser.get<int>("inputWidth");
  46. int imreadRGB = parser.get<int>("RGBInput");
  47. if (!parser.check())
  48. {
  49. parser.printErrors();
  50. return 1;
  51. }
  52. // Load networks
  53. CV_Assert(!detModelPath.empty());
  54. TextDetectionModel_DB detector(detModelPath);
  55. detector.setBinaryThreshold(binThresh)
  56. .setPolygonThreshold(polyThresh)
  57. .setUnclipRatio(unclipRatio)
  58. .setMaxCandidates(maxCandidates);
  59. CV_Assert(!recModelPath.empty());
  60. TextRecognitionModel recognizer(recModelPath);
  61. // Load vocabulary
  62. CV_Assert(!vocPath.empty());
  63. std::ifstream vocFile;
  64. vocFile.open(samples::findFile(vocPath));
  65. CV_Assert(vocFile.is_open());
  66. String vocLine;
  67. std::vector<String> vocabulary;
  68. while (std::getline(vocFile, vocLine)) {
  69. vocabulary.push_back(vocLine);
  70. }
  71. recognizer.setVocabulary(vocabulary);
  72. recognizer.setDecodeType("CTC-greedy");
  73. // Parameters for Detection
  74. double detScale = 1.0 / 255.0;
  75. Size detInputSize = Size(width, height);
  76. Scalar detMean = Scalar(122.67891434, 116.66876762, 104.00698793);
  77. detector.setInputParams(detScale, detInputSize, detMean);
  78. // Parameters for Recognition
  79. double recScale = 1.0 / 127.5;
  80. Scalar recMean = Scalar(127.5);
  81. Size recInputSize = Size(100, 32);
  82. recognizer.setInputParams(recScale, recInputSize, recMean);
  83. // Create a window
  84. static const std::string winName = "Text_Spotting";
  85. // Input data
  86. Mat frame = imread(samples::findFile(parser.get<String>("inputImage")));
  87. std::cout << frame.size << std::endl;
  88. // Inference
  89. std::vector< std::vector<Point> > detResults;
  90. detector.detect(frame, detResults);
  91. if (detResults.size() > 0) {
  92. // Text Recognition
  93. Mat recInput;
  94. if (!imreadRGB) {
  95. cvtColor(frame, recInput, cv::COLOR_BGR2GRAY);
  96. } else {
  97. recInput = frame;
  98. }
  99. std::vector< std::vector<Point> > contours;
  100. for (uint i = 0; i < detResults.size(); i++)
  101. {
  102. const auto& quadrangle = detResults[i];
  103. CV_CheckEQ(quadrangle.size(), (size_t)4, "");
  104. contours.emplace_back(quadrangle);
  105. std::vector<Point2f> quadrangle_2f;
  106. for (int j = 0; j < 4; j++)
  107. quadrangle_2f.emplace_back(quadrangle[j]);
  108. // Transform and Crop
  109. Mat cropped;
  110. fourPointsTransform(recInput, &quadrangle_2f[0], cropped);
  111. std::string recognitionResult = recognizer.recognize(cropped);
  112. std::cout << i << ": '" << recognitionResult << "'" << std::endl;
  113. putText(frame, recognitionResult, quadrangle[3], FONT_HERSHEY_SIMPLEX, 1, Scalar(0, 0, 255), 2);
  114. }
  115. polylines(frame, contours, true, Scalar(0, 255, 0), 2);
  116. } else {
  117. std::cout << "No Text Detected." << std::endl;
  118. }
  119. imshow(winName, frame);
  120. waitKey();
  121. return 0;
  122. }
  123. void fourPointsTransform(const Mat& frame, const Point2f vertices[], Mat& result)
  124. {
  125. const Size outputSize = Size(100, 32);
  126. Point2f targetVertices[4] = {
  127. Point(0, outputSize.height - 1),
  128. Point(0, 0),
  129. Point(outputSize.width - 1, 0),
  130. Point(outputSize.width - 1, outputSize.height - 1)
  131. };
  132. Mat rotationMatrix = getPerspectiveTransform(vertices, targetVertices);
  133. warpPerspective(frame, result, rotationMatrix, outputSize);
  134. #if 0
  135. imshow("roi", result);
  136. waitKey();
  137. #endif
  138. }
  139. bool sortPts(const Point& p1, const Point& p2)
  140. {
  141. return p1.x < p2.x;
  142. }