// obj_detect.cpp
  1. #include <opencv2/dnn.hpp>
  2. #include <opencv2/imgproc.hpp>
  3. #include <opencv2/highgui.hpp>
  4. #include <fstream>
  5. #include <iostream>
  6. #include <cstdlib>
  7. #include <opencv2/core_detect.hpp>
  8. using namespace cv;
  9. using namespace std;
  10. using namespace cv::dnn;
  11. using namespace cv::dnn_objdetect;
  12. int main(int argc, char **argv)
  13. {
  14. if (argc < 4)
  15. {
  16. std::cerr << "Usage " << argv[0] << ": "
  17. << "<model-definition-file> "
  18. << "<model-weights-file> "
  19. << "<test-image> "
  20. << "<threshold>(optional)\n";
  21. return -1;
  22. }
  23. std::string model_prototxt = argv[1];
  24. std::string model_binary = argv[2];
  25. std::string test_input_image = argv[3];
  26. double threshold = 0.7;
  27. if (argc == 5)
  28. {
  29. threshold = atof(argv[4]);
  30. if (threshold > 1.0 || threshold < 0.0)
  31. {
  32. std::cerr << "Threshold should belong to [0, 1]\n";
  33. return -1;
  34. }
  35. }
  36. // Load the network
  37. std::cout << "Loading the network...\n";
  38. Net net = dnn::readNetFromCaffe(model_prototxt, model_binary);
  39. if (net.empty())
  40. {
  41. std::cerr << "Couldn't load the model !\n";
  42. return -2;
  43. }
  44. else
  45. {
  46. std::cout << "Done loading the network !\n\n";
  47. }
  48. // Load the test image
  49. Mat img = cv::imread(test_input_image);
  50. Mat original_img(img);
  51. if (img.empty())
  52. {
  53. std::cerr << "Couldn't load image: " << test_input_image << "\n";
  54. return -3;
  55. }
  56. cv::namedWindow("Initial Image", WINDOW_AUTOSIZE);
  57. cv::imshow("Initial Image", img);
  58. cv::resize(img, img, cv::Size(416, 416));
  59. Mat img_copy(img);
  60. img.convertTo(img, CV_32FC3);
  61. Mat input_blob = blobFromImage(img, 1.0, Size(), cv::Scalar(104, 117, 123), false);
  62. // Set the input blob
  63. // Set the output layers
  64. std::cout << "Getting the output of all the three blobs...\n";
  65. std::vector<Mat> outblobs(3);
  66. std::vector<cv::String> out_layers;
  67. out_layers.push_back("slice");
  68. out_layers.push_back("softmax");
  69. out_layers.push_back("sigmoid");
  70. // Bbox delta blob
  71. std::vector<Mat> temp_blob;
  72. net.setInput(input_blob);
  73. cv::TickMeter t;
  74. t.start();
  75. net.forward(temp_blob, out_layers[0]);
  76. t.stop();
  77. outblobs[0] = temp_blob[2];
  78. // class_scores blob
  79. net.setInput(input_blob);
  80. t.start();
  81. outblobs[1] = net.forward(out_layers[1]);
  82. t.stop();
  83. // conf_scores blob
  84. net.setInput(input_blob);
  85. t.start();
  86. outblobs[2] = net.forward(out_layers[2]);
  87. t.stop();
  88. // Check that the blobs are valid
  89. for (size_t i = 0; i < outblobs.size(); ++i)
  90. {
  91. if (outblobs[i].empty())
  92. {
  93. std::cerr << "Blob: " << i << " is empty !\n";
  94. }
  95. }
  96. int delta_bbox_size[3] = {23, 23, 36};
  97. Mat delta_bbox(3, delta_bbox_size, CV_32F, outblobs[0].ptr<float>());
  98. int class_scores_size[2] = {4761, 20};
  99. Mat class_scores(2, class_scores_size, CV_32F, outblobs[1].ptr<float>());
  100. int conf_scores_size[3] = {23, 23, 9};
  101. Mat conf_scores(3, conf_scores_size, CV_32F, outblobs[2].ptr<float>());
  102. InferBbox inf(delta_bbox, class_scores, conf_scores);
  103. inf.filter(threshold);
  104. double average_time = t.getTimeSec() / t.getCounter();
  105. std::cout << "\nTotal objects detected: " << inf.detections.size()
  106. << " in " << average_time << " seconds\n";
  107. std::cout << "------\n";
  108. float x_ratio = (float)original_img.cols / img_copy.cols;
  109. float y_ratio = (float)original_img.rows / img_copy.rows;
  110. for (size_t i = 0; i < inf.detections.size(); ++i)
  111. {
  112. int xmin = inf.detections[i].xmin;
  113. int ymin = inf.detections[i].ymin;
  114. int xmax = inf.detections[i].xmax;
  115. int ymax = inf.detections[i].ymax;
  116. cv::String class_name = inf.detections[i].label_name;
  117. std::cout << "Class: " << class_name << "\n"
  118. << "Probability: " << inf.detections[i].class_prob << "\n"
  119. << "Co-ordinates: " << inf.detections[i].xmin << " "
  120. << inf.detections[i].ymin << " "
  121. << inf.detections[i].xmax << " "
  122. << inf.detections[i].ymax << "\n";
  123. std::cout << "------\n";
  124. // Draw the corresponding bounding box(s)
  125. cv::rectangle(original_img, cv::Point((int)(xmin * x_ratio), (int)(ymin * y_ratio)),
  126. cv::Point((int)(xmax * x_ratio), (int)(ymax * y_ratio)), cv::Scalar(255, 0, 0), 2);
  127. cv::putText(original_img, class_name, cv::Point((int)(xmin * x_ratio), (int)(ymin * y_ratio)),
  128. cv::FONT_HERSHEY_SIMPLEX, 0.7, cv::Scalar(255, 0, 0), 1);
  129. }
  130. try
  131. {
  132. cv::namedWindow("Final Detections", WINDOW_AUTOSIZE);
  133. cv::imshow("Final Detections", original_img);
  134. cv::imwrite("image.png", original_img);
  135. cv::waitKey(0);
  136. }
  137. catch (const char* msg)
  138. {
  139. std::cerr << msg << "\n";
  140. return -4;
  141. }
  142. return 0;
  143. }