// semantic_segmentation.cpp

#include <algorithm>
#include <cctype>
#include <iostream>
#include <string>
#include <vector>

#include <opencv2/imgproc.hpp>
#include <opencv2/gapi/infer/ie.hpp>
#include <opencv2/gapi/cpu/gcpukernel.hpp>
#include <opencv2/gapi/streaming/cap.hpp>
#include <opencv2/gapi/operators.hpp>
#include <opencv2/highgui.hpp>
const std::string keys =
    "{ h help | | Print this help message }"
    "{ input  | | Path to the input video file }"
    "{ output | | Path to the output video file }"
    "{ ssm    | semantic-segmentation-adas-0001.xml | Path to OpenVINO IE semantic segmentation model (.xml) }";
// 20 colors for 20 classes of semantic-segmentation-adas-0001
const std::vector<cv::Vec3b> colors = {
    { 128,  64, 128 },
    { 232,  35, 244 },
    {  70,  70,  70 },
    { 156, 102, 102 },
    { 153, 153, 190 },
    { 153, 153, 153 },
    {  30, 170, 250 },
    {   0, 220, 220 },
    {  35, 142, 107 },
    { 152, 251, 152 },
    { 180, 130,  70 },
    {  60,  20, 220 },
    {   0,   0, 255 },
    { 142,   0,   0 },
    {  70,   0,   0 },
    { 100,  60,   0 },
    {  90,   0,   0 },
    { 230,   0,   0 },
    {  32,  11, 119 },
    {   0,  74, 111 },
};
namespace {
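// Derive the .bin weights file path from the corresponding .xml model path.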
std::string get_weights_path(const std::string &model_path) {
    const auto EXT_LEN = 4u;
    const auto sz = model_path.size();
    CV_Assert(sz > EXT_LEN);

    auto ext = model_path.substr(sz - EXT_LEN);
    std::transform(ext.begin(), ext.end(), ext.begin(), [](unsigned char c){
        return static_cast<unsigned char>(std::tolower(c));
    });
    CV_Assert(ext == ".xml");
    return model_path.substr(0u, sz - EXT_LEN) + ".bin";
}
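
// Paint each pixel of a single-plane class-id map with the palette color
// of its class.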
void classesToColors(const cv::Mat &out_blob,
                           cv::Mat &mask_img) {
    const int H = out_blob.size[0];
    const int W = out_blob.size[1];

    mask_img.create(H, W, CV_8UC3);
    GAPI_Assert(out_blob.type() == CV_8UC1);
    const uint8_t* const classes = out_blob.ptr<uint8_t>();

    for (int rowId = 0; rowId < H; ++rowId) {
        for (int colId = 0; colId < W; ++colId) {
            uint8_t class_id = classes[rowId * W + colId];
            mask_img.at<cv::Vec3b>(rowId, colId) =
                class_id < colors.size()
                    ? colors[class_id]
                    : cv::Vec3b{0, 0, 0}; // NB: sample supports 20 classes
        }
    }
}
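
// Collapse an NCHW probability blob into a per-pixel argmax class-id map.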
void probsToClasses(const cv::Mat& probs, cv::Mat& classes) {
    const int C = probs.size[1];
    const int H = probs.size[2];
    const int W = probs.size[3];

    classes.create(H, W, CV_8UC1);
    GAPI_Assert(probs.depth() == CV_32F);
    float* out_p = reinterpret_cast<float*>(probs.data);
    uint8_t* classes_p = reinterpret_cast<uint8_t*>(classes.data);

    for (int h = 0; h < H; ++h) {
        for (int w = 0; w < W; ++w) {
            double max = 0;
            int class_id = 0;
            for (int c = 0; c < C; ++c) {
                int idx = c * H * W + h * W + w;
                if (out_p[idx] > max) {
                    max = out_p[idx];
                    class_id = c;
                }
            }
            classes_p[h * W + w] = static_cast<uint8_t>(class_id);
        }
    }
}
} // anonymous namespace

namespace custom {
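// Custom G-API operation: turn the raw network output into a color mask
// with the same size as the input frame.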
G_API_OP(PostProcessing, <cv::GMat(cv::GMat, cv::GMat)>, "sample.custom.post_processing") {
    static cv::GMatDesc outMeta(const cv::GMatDesc &in, const cv::GMatDesc &) {
        return in;
    }
};
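
// OpenCV (CPU) implementation of the PostProcessing operation.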
GAPI_OCV_KERNEL(OCVPostProcessing, PostProcessing) {
    static void run(const cv::Mat &in, const cv::Mat &out_blob, cv::Mat &out) {
        cv::Mat classes;
        // NB: If the output has more than a single plane, it contains per-class
        // probabilities; otherwise it already holds class ids.
        if (out_blob.size[1] > 1) {
            probsToClasses(out_blob, classes);
        } else {
            out_blob.convertTo(classes, CV_8UC1);
            classes = classes.reshape(1, out_blob.size[2]);
        }

        cv::Mat mask_img;
        classesToColors(classes, mask_img);
        cv::resize(mask_img, out, in.size());
    }
};
} // namespace custom

int main(int argc, char *argv[]) {
    cv::CommandLineParser cmd(argc, argv, keys);
    if (cmd.has("help")) {
        cmd.printMessage();
        return 0;
    }

    // Prepare parameters first
    const std::string input = cmd.get<std::string>("input");
    const std::string output = cmd.get<std::string>("output");
    const auto model_path = cmd.get<std::string>("ssm");
    const auto weights_path = get_weights_path(model_path);
    const auto device = "CPU";
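    // Declare the network interface: one image in, one output blob out.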
    G_API_NET(SemSegmNet, <cv::GMat(cv::GMat)>, "semantic-segmentation");
    const auto net = cv::gapi::ie::Params<SemSegmNet> {
        model_path, weights_path, device
    };
    const auto kernels = cv::gapi::kernels<custom::OCVPostProcessing>();
    const auto networks = cv::gapi::networks(net);

    // Now build the graph
    cv::GMat in;
    cv::GMat out_blob = cv::gapi::infer<SemSegmNet>(in);
    cv::GMat post_proc_out = custom::PostProcessing::on(in, out_blob);
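    // Blend the original frame with the color mask: 30% frame + 70% mask.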
    cv::GMat blending_in = in * 0.3f;
    cv::GMat blending_out = post_proc_out * 0.7f;
    cv::GMat out = blending_in + blending_out;

    cv::GStreamingCompiled pipeline = cv::GComputation(cv::GIn(in), cv::GOut(out))
        .compileStreaming(cv::compile_args(kernels, networks));
    auto inputs = cv::gin(cv::gapi::wip::make_src<cv::gapi::wip::GCaptureSource>(input));

    // The execution part
    pipeline.setSource(std::move(inputs));

    cv::VideoWriter writer;
    cv::TickMeter tm;
    cv::Mat outMat;

    std::size_t frames = 0u;
    tm.start();
    pipeline.start();
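    // Pull processed frames until the input source is exhausted.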
    while (pipeline.pull(cv::gout(outMat))) {
        ++frames;
        cv::imshow("Out", outMat);
        cv::waitKey(1);
        if (!output.empty()) {
            if (!writer.isOpened()) {
                const auto sz = cv::Size{outMat.cols, outMat.rows};
                writer.open(output, cv::VideoWriter::fourcc('M','J','P','G'), 25.0, sz);
                CV_Assert(writer.isOpened());
            }
            writer << outMat;
        }
    }
    tm.stop();
    std::cout << "Processed " << frames << " frames"
              << " (" << frames / tm.getTimeSec() << " FPS)" << std::endl;
    return 0;
}
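
// Example invocation (binary name and file paths are illustrative, not part
// of the sample itself):
//   ./semantic_segmentation --input=road.mp4 --output=result.avi \
//       --ssm=semantic-segmentation-adas-0001.xml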