// scene_text_recognition.cpp
// OpenCV DNN sample: scene text recognition with a CRNN model.

  1. #include <iostream>
  2. #include <fstream>
  3. #include <opencv2/imgproc.hpp>
  4. #include <opencv2/highgui.hpp>
  5. #include <opencv2/dnn/dnn.hpp>
  6. using namespace cv;
  7. using namespace cv::dnn;
  8. String keys =
  9. "{ help h | | Print help message. }"
  10. "{ inputImage i | | Path to an input image. Skip this argument to capture frames from a camera. }"
  11. "{ modelPath mp | | Path to a binary .onnx file contains trained CRNN text recognition model. "
  12. "Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}"
  13. "{ RGBInput rgb |0| 0: imread with flags=IMREAD_GRAYSCALE; 1: imread with flags=IMREAD_COLOR. }"
  14. "{ evaluate e |false| false: predict with input images; true: evaluate on benchmarks. }"
  15. "{ evalDataPath edp | | Path to benchmarks for evaluation. "
  16. "Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}"
  17. "{ vocabularyPath vp | alphabet_36.txt | Path to recognition vocabulary. "
  18. "Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}";
  19. String convertForEval(String &input);
  20. int main(int argc, char** argv)
  21. {
  22. // Parse arguments
  23. CommandLineParser parser(argc, argv, keys);
  24. parser.about("Use this script to run the PyTorch implementation of "
  25. "An End-to-End Trainable Neural Network for Image-based SequenceRecognition and Its Application to Scene Text Recognition "
  26. "(https://arxiv.org/abs/1507.05717)");
  27. if (argc == 1 || parser.has("help"))
  28. {
  29. parser.printMessage();
  30. return 0;
  31. }
  32. String modelPath = parser.get<String>("modelPath");
  33. String vocPath = parser.get<String>("vocabularyPath");
  34. int imreadRGB = parser.get<int>("RGBInput");
  35. if (!parser.check())
  36. {
  37. parser.printErrors();
  38. return 1;
  39. }
  40. // Load the network
  41. CV_Assert(!modelPath.empty());
  42. TextRecognitionModel recognizer(modelPath);
  43. // Load vocabulary
  44. CV_Assert(!vocPath.empty());
  45. std::ifstream vocFile;
  46. vocFile.open(samples::findFile(vocPath));
  47. CV_Assert(vocFile.is_open());
  48. String vocLine;
  49. std::vector<String> vocabulary;
  50. while (std::getline(vocFile, vocLine)) {
  51. vocabulary.push_back(vocLine);
  52. }
  53. recognizer.setVocabulary(vocabulary);
  54. recognizer.setDecodeType("CTC-greedy");
  55. // Set parameters
  56. double scale = 1.0 / 127.5;
  57. Scalar mean = Scalar(127.5, 127.5, 127.5);
  58. Size inputSize = Size(100, 32);
  59. recognizer.setInputParams(scale, inputSize, mean);
  60. if (parser.get<bool>("evaluate"))
  61. {
  62. // For evaluation
  63. String evalDataPath = parser.get<String>("evalDataPath");
  64. CV_Assert(!evalDataPath.empty());
  65. String gtPath = evalDataPath + "/test_gts.txt";
  66. std::ifstream evalGts;
  67. evalGts.open(gtPath);
  68. CV_Assert(evalGts.is_open());
  69. String gtLine;
  70. int cntRight=0, cntAll=0;
  71. TickMeter timer;
  72. timer.reset();
  73. while (std::getline(evalGts, gtLine)) {
  74. size_t splitLoc = gtLine.find_first_of(' ');
  75. String imgPath = evalDataPath + '/' + gtLine.substr(0, splitLoc);
  76. String gt = gtLine.substr(splitLoc+1);
  77. // Inference
  78. Mat frame = imread(samples::findFile(imgPath), imreadRGB);
  79. CV_Assert(!frame.empty());
  80. timer.start();
  81. std::string recognitionResult = recognizer.recognize(frame);
  82. timer.stop();
  83. if (gt == convertForEval(recognitionResult))
  84. cntRight++;
  85. cntAll++;
  86. }
  87. std::cout << "Accuracy(%): " << (double)(cntRight) / (double)(cntAll) << std::endl;
  88. std::cout << "Average Inference Time(ms): " << timer.getTimeMilli() / (double)(cntAll) << std::endl;
  89. }
  90. else
  91. {
  92. // Create a window
  93. static const std::string winName = "Input Cropped Image";
  94. // Open an image file
  95. CV_Assert(parser.has("inputImage"));
  96. Mat frame = imread(samples::findFile(parser.get<String>("inputImage")), imreadRGB);
  97. CV_Assert(!frame.empty());
  98. // Recognition
  99. std::string recognitionResult = recognizer.recognize(frame);
  100. imshow(winName, frame);
  101. std::cout << "Predition: '" << recognitionResult << "'" << std::endl;
  102. waitKey();
  103. }
  104. return 0;
  105. }
  106. // Convert the predictions to lower case, and remove other characters.
  107. // Only for Evaluation
  108. String convertForEval(String & input)
  109. {
  110. String output;
  111. for (uint i = 0; i < input.length(); i++){
  112. char ch = input[i];
  113. if ((int)ch >= 97 && (int)ch <= 122) {
  114. output.push_back(ch);
  115. } else if ((int)ch >= 65 && (int)ch <= 90) {
  116. output.push_back((char)(ch + 32));
  117. } else {
  118. continue;
  119. }
  120. }
  121. return output;
  122. }