123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144 |
- #include <iostream>
- #include <fstream>
- #include <opencv2/imgproc.hpp>
- #include <opencv2/highgui.hpp>
- #include <opencv2/dnn/dnn.hpp>
- using namespace cv;
- using namespace cv::dnn;
- String keys =
- "{ help h | | Print help message. }"
- "{ inputImage i | | Path to an input image. Skip this argument to capture frames from a camera. }"
- "{ modelPath mp | | Path to a binary .onnx file contains trained CRNN text recognition model. "
- "Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}"
- "{ RGBInput rgb |0| 0: imread with flags=IMREAD_GRAYSCALE; 1: imread with flags=IMREAD_COLOR. }"
- "{ evaluate e |false| false: predict with input images; true: evaluate on benchmarks. }"
- "{ evalDataPath edp | | Path to benchmarks for evaluation. "
- "Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}"
- "{ vocabularyPath vp | alphabet_36.txt | Path to recognition vocabulary. "
- "Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}";
- String convertForEval(String &input);
- int main(int argc, char** argv)
- {
- // Parse arguments
- CommandLineParser parser(argc, argv, keys);
- parser.about("Use this script to run the PyTorch implementation of "
- "An End-to-End Trainable Neural Network for Image-based SequenceRecognition and Its Application to Scene Text Recognition "
- "(https://arxiv.org/abs/1507.05717)");
- if (argc == 1 || parser.has("help"))
- {
- parser.printMessage();
- return 0;
- }
- String modelPath = parser.get<String>("modelPath");
- String vocPath = parser.get<String>("vocabularyPath");
- int imreadRGB = parser.get<int>("RGBInput");
- if (!parser.check())
- {
- parser.printErrors();
- return 1;
- }
- // Load the network
- CV_Assert(!modelPath.empty());
- TextRecognitionModel recognizer(modelPath);
- // Load vocabulary
- CV_Assert(!vocPath.empty());
- std::ifstream vocFile;
- vocFile.open(samples::findFile(vocPath));
- CV_Assert(vocFile.is_open());
- String vocLine;
- std::vector<String> vocabulary;
- while (std::getline(vocFile, vocLine)) {
- vocabulary.push_back(vocLine);
- }
- recognizer.setVocabulary(vocabulary);
- recognizer.setDecodeType("CTC-greedy");
- // Set parameters
- double scale = 1.0 / 127.5;
- Scalar mean = Scalar(127.5, 127.5, 127.5);
- Size inputSize = Size(100, 32);
- recognizer.setInputParams(scale, inputSize, mean);
- if (parser.get<bool>("evaluate"))
- {
- // For evaluation
- String evalDataPath = parser.get<String>("evalDataPath");
- CV_Assert(!evalDataPath.empty());
- String gtPath = evalDataPath + "/test_gts.txt";
- std::ifstream evalGts;
- evalGts.open(gtPath);
- CV_Assert(evalGts.is_open());
- String gtLine;
- int cntRight=0, cntAll=0;
- TickMeter timer;
- timer.reset();
- while (std::getline(evalGts, gtLine)) {
- size_t splitLoc = gtLine.find_first_of(' ');
- String imgPath = evalDataPath + '/' + gtLine.substr(0, splitLoc);
- String gt = gtLine.substr(splitLoc+1);
- // Inference
- Mat frame = imread(samples::findFile(imgPath), imreadRGB);
- CV_Assert(!frame.empty());
- timer.start();
- std::string recognitionResult = recognizer.recognize(frame);
- timer.stop();
- if (gt == convertForEval(recognitionResult))
- cntRight++;
- cntAll++;
- }
- std::cout << "Accuracy(%): " << (double)(cntRight) / (double)(cntAll) << std::endl;
- std::cout << "Average Inference Time(ms): " << timer.getTimeMilli() / (double)(cntAll) << std::endl;
- }
- else
- {
- // Create a window
- static const std::string winName = "Input Cropped Image";
- // Open an image file
- CV_Assert(parser.has("inputImage"));
- Mat frame = imread(samples::findFile(parser.get<String>("inputImage")), imreadRGB);
- CV_Assert(!frame.empty());
- // Recognition
- std::string recognitionResult = recognizer.recognize(frame);
- imshow(winName, frame);
- std::cout << "Predition: '" << recognitionResult << "'" << std::endl;
- waitKey();
- }
- return 0;
- }
// Normalizes a prediction for evaluation: ASCII upper-case letters are converted to
// lower case, ASCII lower-case letters are kept, and every other character (digits,
// punctuation, whitespace, non-ASCII bytes) is dropped.
// Only for evaluation.
//
// input: the raw recognition result; it is read but never modified.
// Returns the normalized copy (empty when the input contains no ASCII letters).
//
// The signature is spelled with std::string, the type main() passes in; it is the
// same type as the String used in the forward declaration.
std::string convertForEval(std::string & input)
{
    std::string output;
    output.reserve(input.size());
    for (size_t i = 0; i < input.size(); ++i)
    {
        const char ch = input[i];
        if (ch >= 'a' && ch <= 'z')
        {
            output.push_back(ch);
        }
        else if (ch >= 'A' && ch <= 'Z')
        {
            // Explicit ASCII mapping keeps the result independent of the locale.
            output.push_back(static_cast<char>(ch - 'A' + 'a'));
        }
    }
    return output;
}
|