classification.cpp 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. #include <fstream>
  2. #include <sstream>
  3. #include <iostream>
  4. #include <opencv2/dnn.hpp>
  5. #include <opencv2/imgproc.hpp>
  6. #include <opencv2/highgui.hpp>
  7. #include "common.hpp"
  // Command-line key table handed to cv::CommandLineParser in main().
  // It is deliberately a mutable global: main() appends further keys to it
  // (the result of genPreprocArguments) before parsing a second time.
  // Each "{ name alias | default | help }" group declares one option.
  std::string keys =
      "{ help h | | Print help message. }"
      "{ @alias | | An alias name of model to extract preprocessing parameters from models.yml file. }"
      "{ zoo | models.yml | An optional path to file with preprocessing parameters }"
      "{ input i | | Path to input image or video file. Skip this argument to capture frames from a camera.}"
      "{ initial_width | 0 | Preprocess input image by initial resizing to a specific width.}"
      "{ initial_height | 0 | Preprocess input image by initial resizing to a specific height.}"
      "{ std | 0.0 0.0 0.0 | Preprocess input image by dividing on a standard deviation.}"
      "{ crop | false | Preprocess input image by center cropping.}"
      "{ framework f | | Optional name of an origin framework of the model. Detect it automatically if it does not set. }"
      "{ needSoftmax | false | Use Softmax to post-process the output of the net.}"
      "{ classes | | Optional path to a text file with names of classes. }"
      "{ backend | 0 | Choose one of computation backends: "
      "0: automatically (by default), "
      "1: Halide language (http://halide-lang.org/), "
      "2: Intel's Deep Learning Inference Engine (https://software.intel.com/openvino-toolkit), "
      "3: OpenCV implementation, "
      "4: VKCOM, "
      "5: CUDA, "
      "6: WebNN }"
      "{ target | 0 | Choose one of target computation devices: "
      "0: CPU target (by default), "
      "1: OpenCL, "
      "2: OpenCL fp16 (half-float precision), "
      "3: VPU, "
      "4: Vulkan, "
      "6: CUDA, "
      // Fixed help text: this entry used to say "half-float preprocess",
      // which contradicted the wording of the OpenCL fp16 entry above.
      "7: CUDA fp16 (half-float precision) }";
  36. using namespace cv;
  37. using namespace dnn;
  38. std::vector<std::string> classes;
  39. int main(int argc, char** argv)
  40. {
  41. CommandLineParser parser(argc, argv, keys);
  42. const std::string modelName = parser.get<String>("@alias");
  43. const std::string zooFile = parser.get<String>("zoo");
  44. keys += genPreprocArguments(modelName, zooFile);
  45. parser = CommandLineParser(argc, argv, keys);
  46. parser.about("Use this script to run classification deep learning networks using OpenCV.");
  47. if (argc == 1 || parser.has("help"))
  48. {
  49. parser.printMessage();
  50. return 0;
  51. }
  52. int rszWidth = parser.get<int>("initial_width");
  53. int rszHeight = parser.get<int>("initial_height");
  54. float scale = parser.get<float>("scale");
  55. Scalar mean = parser.get<Scalar>("mean");
  56. Scalar std = parser.get<Scalar>("std");
  57. bool swapRB = parser.get<bool>("rgb");
  58. bool crop = parser.get<bool>("crop");
  59. int inpWidth = parser.get<int>("width");
  60. int inpHeight = parser.get<int>("height");
  61. String model = findFile(parser.get<String>("model"));
  62. String config = findFile(parser.get<String>("config"));
  63. String framework = parser.get<String>("framework");
  64. int backendId = parser.get<int>("backend");
  65. int targetId = parser.get<int>("target");
  66. bool needSoftmax = parser.get<bool>("needSoftmax");
  67. std::cout<<"mean: "<<mean<<std::endl;
  68. std::cout<<"std: "<<std<<std::endl;
  69. // Open file with classes names.
  70. if (parser.has("classes"))
  71. {
  72. std::string file = parser.get<String>("classes");
  73. std::ifstream ifs(file.c_str());
  74. if (!ifs.is_open())
  75. CV_Error(Error::StsError, "File " + file + " not found");
  76. std::string line;
  77. while (std::getline(ifs, line))
  78. {
  79. classes.push_back(line);
  80. }
  81. }
  82. if (!parser.check())
  83. {
  84. parser.printErrors();
  85. return 1;
  86. }
  87. CV_Assert(!model.empty());
  88. //! [Read and initialize network]
  89. Net net = readNet(model, config, framework);
  90. net.setPreferableBackend(backendId);
  91. net.setPreferableTarget(targetId);
  92. //! [Read and initialize network]
  93. // Create a window
  94. static const std::string kWinName = "Deep learning image classification in OpenCV";
  95. namedWindow(kWinName, WINDOW_NORMAL);
  96. //! [Open a video file or an image file or a camera stream]
  97. VideoCapture cap;
  98. if (parser.has("input"))
  99. cap.open(parser.get<String>("input"));
  100. else
  101. cap.open(0);
  102. //! [Open a video file or an image file or a camera stream]
  103. // Process frames.
  104. Mat frame, blob;
  105. while (waitKey(1) < 0)
  106. {
  107. cap >> frame;
  108. if (frame.empty())
  109. {
  110. waitKey();
  111. break;
  112. }
  113. if (rszWidth != 0 && rszHeight != 0)
  114. {
  115. resize(frame, frame, Size(rszWidth, rszHeight));
  116. }
  117. //! [Create a 4D blob from a frame]
  118. blobFromImage(frame, blob, scale, Size(inpWidth, inpHeight), mean, swapRB, crop);
  119. // Check std values.
  120. if (std.val[0] != 0.0 && std.val[1] != 0.0 && std.val[2] != 0.0)
  121. {
  122. // Divide blob by std.
  123. divide(blob, std, blob);
  124. }
  125. //! [Create a 4D blob from a frame]
  126. //! [Set input blob]
  127. net.setInput(blob);
  128. //! [Set input blob]
  129. //! [Make forward pass]
  130. // double t_sum = 0.0;
  131. // double t;
  132. int classId;
  133. double confidence;
  134. cv::TickMeter timeRecorder;
  135. timeRecorder.reset();
  136. Mat prob = net.forward();
  137. double t1;
  138. timeRecorder.start();
  139. prob = net.forward();
  140. timeRecorder.stop();
  141. t1 = timeRecorder.getTimeMilli();
  142. timeRecorder.reset();
  143. for(int i = 0; i < 200; i++) {
  144. //! [Make forward pass]
  145. timeRecorder.start();
  146. prob = net.forward();
  147. timeRecorder.stop();
  148. //! [Get a class with a highest score]
  149. Point classIdPoint;
  150. minMaxLoc(prob.reshape(1, 1), 0, &confidence, 0, &classIdPoint);
  151. classId = classIdPoint.x;
  152. //! [Get a class with a highest score]
  153. // Put efficiency information.
  154. // std::vector<double> layersTimes;
  155. // double freq = getTickFrequency() / 1000;
  156. // t = net.getPerfProfile(layersTimes) / freq;
  157. // t_sum += t;
  158. }
  159. if (needSoftmax == true)
  160. {
  161. float maxProb = 0.0;
  162. float sum = 0.0;
  163. Mat softmaxProb;
  164. maxProb = *std::max_element(prob.begin<float>(), prob.end<float>());
  165. cv::exp(prob-maxProb, softmaxProb);
  166. sum = (float)cv::sum(softmaxProb)[0];
  167. softmaxProb /= sum;
  168. Point classIdPoint;
  169. minMaxLoc(softmaxProb.reshape(1, 1), 0, &confidence, 0, &classIdPoint);
  170. classId = classIdPoint.x;
  171. }
  172. std::string label = format("Inference time of 1 round: %.2f ms", t1);
  173. std::string label2 = format("Average time of 200 rounds: %.2f ms", timeRecorder.getTimeMilli()/200);
  174. putText(frame, label, Point(0, 15), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 255, 0));
  175. putText(frame, label2, Point(0, 35), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 255, 0));
  176. // Print predicted class.
  177. label = format("%s: %.4f", (classes.empty() ? format("Class #%d", classId).c_str() :
  178. classes[classId].c_str()),
  179. confidence);
  180. putText(frame, label, Point(0, 55), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 255, 0));
  181. imshow(kWinName, frame);
  182. }
  183. return 0;
  184. }