segmentation.cpp

#include <fstream>
#include <sstream>

#include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>

#include "common.hpp"
std::string keys =
    "{ help  h     | | Print help message. }"
    "{ @alias      | | An alias name of model to extract preprocessing parameters from models.yml file. }"
    "{ zoo         | models.yml | An optional path to file with preprocessing parameters }"
    "{ device      | 0 | camera device number. }"
    "{ input i     | | Path to input image or video file. Skip this argument to capture frames from a camera. }"
    "{ framework f | | Optional name of an origin framework of the model. Detect it automatically if it is not set. }"
    "{ classes     | | Optional path to a text file with names of classes. }"
    "{ colors      | | Optional path to a text file with colors for every class. "
                      "Each color is represented by three values from 0 to 255 in BGR channel order. }"
    "{ backend     | 0 | Choose one of computation backends: "
                        "0: automatically (by default), "
                        "1: Halide language (http://halide-lang.org/), "
                        "2: Intel's Deep Learning Inference Engine (https://software.intel.com/openvino-toolkit), "
                        "3: OpenCV implementation, "
                        "4: VKCOM, "
                        "5: CUDA }"
    "{ target      | 0 | Choose one of target computation devices: "
                        "0: CPU target (by default), "
                        "1: OpenCL, "
                        "2: OpenCL fp16 (half-float precision), "
                        "3: VPU, "
                        "4: Vulkan, "
                        "6: CUDA, "
                        "7: CUDA fp16 (half-float precision) }";
using namespace cv;
using namespace dnn;

std::vector<std::string> classes;
std::vector<Vec3b> colors;

void showLegend();

void colorizeSegmentation(const Mat &score, Mat &segm);

int main(int argc, char** argv)
{
    CommandLineParser parser(argc, argv, keys);

    const std::string modelName = parser.get<String>("@alias");
    const std::string zooFile = parser.get<String>("zoo");

    keys += genPreprocArguments(modelName, zooFile);

    parser = CommandLineParser(argc, argv, keys);
    parser.about("Use this script to run semantic segmentation deep learning networks using OpenCV.");
    if (argc == 1 || parser.has("help"))
    {
        parser.printMessage();
        return 0;
    }
    float scale = parser.get<float>("scale");
    Scalar mean = parser.get<Scalar>("mean");
    bool swapRB = parser.get<bool>("rgb");
    int inpWidth = parser.get<int>("width");
    int inpHeight = parser.get<int>("height");
    String model = findFile(parser.get<String>("model"));
    String config = findFile(parser.get<String>("config"));
    String framework = parser.get<String>("framework");
    int backendId = parser.get<int>("backend");
    int targetId = parser.get<int>("target");
    // Open file with classes names.
    if (parser.has("classes"))
    {
        std::string file = parser.get<String>("classes");
        std::ifstream ifs(file.c_str());
        if (!ifs.is_open())
            CV_Error(Error::StsError, "File " + file + " not found");
        std::string line;
        while (std::getline(ifs, line))
        {
            classes.push_back(line);
        }
    }
    // Open file with colors (one "B G R" triplet per line).
    if (parser.has("colors"))
    {
        std::string file = parser.get<String>("colors");
        std::ifstream ifs(file.c_str());
        if (!ifs.is_open())
            CV_Error(Error::StsError, "File " + file + " not found");
        std::string line;
        while (std::getline(ifs, line))
        {
            std::istringstream colorStr(line.c_str());
            Vec3b color;
            for (int i = 0; i < 3 && !colorStr.eof(); ++i)
            {
                // Read the value as an integer: streaming directly into a uchar
                // would extract a single character instead of the number.
                int value = 0;
                colorStr >> value;
                color[i] = saturate_cast<uchar>(value);
            }
            colors.push_back(color);
        }
    }
    if (!parser.check())
    {
        parser.printErrors();
        return 1;
    }

    CV_Assert(!model.empty());
    //! [Read and initialize network]
    Net net = readNet(model, config, framework);
    net.setPreferableBackend(backendId);
    net.setPreferableTarget(targetId);
    //! [Read and initialize network]

    // Create a window
    static const std::string kWinName = "Deep learning semantic segmentation in OpenCV";
    namedWindow(kWinName, WINDOW_NORMAL);

    //! [Open a video file or an image file or a camera stream]
    VideoCapture cap;
    if (parser.has("input"))
        cap.open(parser.get<String>("input"));
    else
        cap.open(parser.get<int>("device"));
    //! [Open a video file or an image file or a camera stream]
    // Process frames.
    Mat frame, blob;
    while (waitKey(1) < 0)
    {
        cap >> frame;
        if (frame.empty())
        {
            waitKey();
            break;
        }

        //! [Create a 4D blob from a frame]
        blobFromImage(frame, blob, scale, Size(inpWidth, inpHeight), mean, swapRB, false);
        //! [Create a 4D blob from a frame]

        //! [Set input blob]
        net.setInput(blob);
        //! [Set input blob]

        //! [Make forward pass]
        Mat score = net.forward();
        //! [Make forward pass]

        Mat segm;
        colorizeSegmentation(score, segm);

        resize(segm, segm, frame.size(), 0, 0, INTER_NEAREST);
        addWeighted(frame, 0.1, segm, 0.9, 0.0, frame);

        // Put efficiency information.
        std::vector<double> layersTimes;
        double freq = getTickFrequency() / 1000;
        double t = net.getPerfProfile(layersTimes) / freq;
        std::string label = format("Inference time: %.2f ms", t);
        putText(frame, label, Point(0, 15), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 255, 0));

        imshow(kWinName, frame);
        if (!classes.empty())
            showLegend();
    }
    return 0;
}
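// The network's output blob has shape 1 x numClasses x H x W (one score map per class).
// colorizeSegmentation() takes the per-pixel argmax over the class dimension and paints
// each pixel with the color assigned to the winning class.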
void colorizeSegmentation(const Mat &score, Mat &segm)
{
    const int rows = score.size[2];
    const int cols = score.size[3];
    const int chns = score.size[1];

    if (colors.empty())
    {
        // Generate colors.
        colors.push_back(Vec3b());
        for (int i = 1; i < chns; ++i)
        {
            Vec3b color;
            for (int j = 0; j < 3; ++j)
                color[j] = (colors[i - 1][j] + rand() % 256) / 2;
            colors.push_back(color);
        }
    }
    else if (chns != (int)colors.size())
    {
        CV_Error(Error::StsError, format("Number of output classes does not match "
                                         "number of colors (%d != %zu)", chns, colors.size()));
    }

    Mat maxCl = Mat::zeros(rows, cols, CV_8UC1);
    Mat maxVal(rows, cols, CV_32FC1, score.data);
    for (int ch = 1; ch < chns; ch++)
    {
        for (int row = 0; row < rows; row++)
        {
            const float *ptrScore = score.ptr<float>(0, ch, row);
            uint8_t *ptrMaxCl = maxCl.ptr<uint8_t>(row);
            float *ptrMaxVal = maxVal.ptr<float>(row);
            for (int col = 0; col < cols; col++)
            {
                if (ptrScore[col] > ptrMaxVal[col])
                {
                    ptrMaxVal[col] = ptrScore[col];
                    ptrMaxCl[col] = (uchar)ch;
                }
            }
        }
    }

    segm.create(rows, cols, CV_8UC3);
    for (int row = 0; row < rows; row++)
    {
        const uchar *ptrMaxCl = maxCl.ptr<uchar>(row);
        Vec3b *ptrSegm = segm.ptr<Vec3b>(row);
        for (int col = 0; col < cols; col++)
        {
            ptrSegm[col] = colors[ptrMaxCl[col]];
        }
    }
}
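// showLegend() renders one kBlockHeight-tall color swatch per class, labeled with the
// class name, in a separate "Legend" window so the overlay colors can be interpreted.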
void showLegend()
{
    static const int kBlockHeight = 30;
    static Mat legend;
    if (legend.empty())
    {
        const int numClasses = (int)classes.size();
        if ((int)colors.size() != numClasses)
        {
            CV_Error(Error::StsError, format("Number of output classes does not match "
                                             "number of labels (%zu != %zu)", colors.size(), classes.size()));
        }
        legend.create(kBlockHeight * numClasses, 200, CV_8UC3);
        for (int i = 0; i < numClasses; i++)
        {
            Mat block = legend.rowRange(i * kBlockHeight, (i + 1) * kBlockHeight);
            block.setTo(colors[i]);
            putText(block, classes[i], Point(0, kBlockHeight / 2), FONT_HERSHEY_SIMPLEX, 0.5, Vec3b(255, 255, 255));
        }
        namedWindow("Legend", WINDOW_NORMAL);
        imshow("Legend", legend);
    }
}
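// Example invocation (a sketch, assuming the sample is built with OpenCV's usual
// example naming and that models.yml defines the alias used here; the alias and the
// class/color file names below are placeholders):
//   ./example_dnn_segmentation fcn8s --input=street.jpg \
//       --classes=voc_classes.txt --colors=voc_colors.txt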