// person_reid.cpp
  1. //
  2. // You can download a baseline ReID model and sample input from:
  3. // https://github.com/ReID-Team/ReID_extra_testdata
  4. //
  5. // Authors of samples and Youtu ReID baseline:
  6. // Xing Sun <winfredsun@tencent.com>
  7. // Feng Zheng <zhengf@sustech.edu.cn>
  8. // Xinyang Jiang <sevjiang@tencent.com>
  9. // Fufu Yu <fufuyu@tencent.com>
  10. // Enwei Zhang <miyozhang@tencent.com>
  11. //
  12. // Copyright (C) 2020-2021, Tencent.
  13. // Copyright (C) 2020-2021, SUSTech.
  14. //
#include <algorithm>
#include <cmath>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/dnn.hpp>
  20. using namespace cv;
  21. using namespace cv::dnn;
// Command-line option table passed to cv::CommandLineParser in main().
// Each brace group reads "{name [alias] | default | help text}". The adjacent
// string literals below are concatenated by the compiler into one string, so
// the spacing inside the quotes is part of the table.
const char* keys =
"{help h | | show help message}"
// Inputs: these three entries have an empty default column.
"{model m | | network model}"
"{query_list q | | list of query images}"
"{gallery_list g | | list of gallery images}"
// Inference settings read in main() with parser.get<int>().
"{batch_size | 32 | batch size of each inference}"
"{resize_h | 256 | resize input to specific height.}"
"{resize_w | 128 | resize input to specific width.}"
// Visualization settings.
"{topk k | 5 | number of gallery images showed in visualization}"
"{output_dir | | path for visualization(it should be existed)}"
// Integer code forwarded unchanged to Net::setPreferableBackend() in main().
"{backend b | 0 | choose one of computation backends: "
"0: automatically (by default), "
"1: Halide language (http://halide-lang.org/), "
"2: Intel's Deep Learning Inference Engine (https://software.intel.com/openvino-toolkit), "
"3: OpenCV implementation, "
"4: VKCOM, "
"5: CUDA }"
// Integer code forwarded unchanged to Net::setPreferableTarget() in main().
"{target t | 0 | choose one of target computation devices: "
"0: CPU target (by default), "
"1: OpenCL, "
"2: OpenCL fp16 (half-float precision), "
"4: Vulkan, "
"6: CUDA, "
"7: CUDA fp16 (half-float preprocess) }";
  46. namespace cv{
  47. namespace reid{
  48. static Mat preprocess(const Mat& img)
  49. {
  50. const double mean[3] = {0.485, 0.456, 0.406};
  51. const double std[3] = {0.229, 0.224, 0.225};
  52. Mat ret = Mat(img.rows, img.cols, CV_32FC3);
  53. for (int y = 0; y < ret.rows; y ++)
  54. {
  55. for (int x = 0; x < ret.cols; x++)
  56. {
  57. for (int c = 0; c < 3; c++)
  58. {
  59. ret.at<Vec3f>(y,x)[c] = (float)((img.at<Vec3b>(y,x)[c] / 255.0 - mean[2 - c]) / std[2 - c]);
  60. }
  61. }
  62. }
  63. return ret;
  64. }
  65. static std::vector<float> normalization(const std::vector<float>& feature)
  66. {
  67. std::vector<float> ret;
  68. float sum = 0.0;
  69. for(int i = 0; i < (int)feature.size(); i++)
  70. {
  71. sum += feature[i] * feature[i];
  72. }
  73. sum = sqrt(sum);
  74. for(int i = 0; i < (int)feature.size(); i++)
  75. {
  76. ret.push_back(feature[i] / sum);
  77. }
  78. return ret;
  79. }
  80. static void extractFeatures(const std::vector<std::string>& imglist, Net* net, const int& batch_size, const int& resize_h, const int& resize_w, std::vector<std::vector<float>>& features)
  81. {
  82. for(int st = 0; st < (int)imglist.size(); st += batch_size)
  83. {
  84. std::vector<Mat> batch;
  85. for(int delta = 0; delta < batch_size && st + delta < (int)imglist.size(); delta++)
  86. {
  87. Mat img = imread(imglist[st + delta]);
  88. batch.push_back(preprocess(img));
  89. }
  90. Mat blob = dnn::blobFromImages(batch, 1.0, Size(resize_w, resize_h), Scalar(0.0,0.0,0.0), true, false, CV_32F);
  91. net->setInput(blob);
  92. Mat out = net->forward();
  93. for(int i = 0; i < (int)out.size().height; i++)
  94. {
  95. std::vector<float> temp_feature;
  96. for(int j = 0; j < (int)out.size().width; j++)
  97. {
  98. temp_feature.push_back(out.at<float>(i,j));
  99. }
  100. features.push_back(normalization(temp_feature));
  101. }
  102. }
  103. return ;
  104. }
  105. static void getNames(const std::string& ImageList, std::vector<std::string>& result)
  106. {
  107. std::ifstream img_in(ImageList);
  108. std::string img_name;
  109. while(img_in >> img_name)
  110. {
  111. result.push_back(img_name);
  112. }
  113. return ;
  114. }
  115. static float similarity(const std::vector<float>& feature1, const std::vector<float>& feature2)
  116. {
  117. float result = 0.0;
  118. for(int i = 0; i < (int)feature1.size(); i++)
  119. {
  120. result += feature1[i] * feature2[i];
  121. }
  122. return result;
  123. }
  124. static void getTopK(const std::vector<std::vector<float>>& queryFeatures, const std::vector<std::vector<float>>& galleryFeatures, const int& topk, std::vector<std::vector<int>>& result)
  125. {
  126. for(int i = 0; i < (int)queryFeatures.size(); i++)
  127. {
  128. std::vector<float> similarityList;
  129. std::vector<int> index;
  130. for(int j = 0; j < (int)galleryFeatures.size(); j++)
  131. {
  132. similarityList.push_back(similarity(queryFeatures[i], galleryFeatures[j]));
  133. index.push_back(j);
  134. }
  135. sort(index.begin(), index.end(), [&](int x,int y){return similarityList[x] > similarityList[y];});
  136. std::vector<int> topk_result;
  137. for(int j = 0; j < min(topk, (int)index.size()); j++)
  138. {
  139. topk_result.push_back(index[j]);
  140. }
  141. result.push_back(topk_result);
  142. }
  143. return ;
  144. }
  145. static void addBorder(const Mat& img, const Scalar& color, Mat& result)
  146. {
  147. const int bordersize = 5;
  148. copyMakeBorder(img, result, bordersize, bordersize, bordersize, bordersize, cv::BORDER_CONSTANT, color);
  149. return ;
  150. }
  151. static void drawRankList(const std::string& queryName, const std::vector<std::string>& galleryImageNames, const std::vector<int>& topk_index, const int& resize_h, const int& resize_w, Mat& result)
  152. {
  153. const Size outputSize = Size(resize_w, resize_h);
  154. Mat q_img = imread(queryName), temp_img;
  155. resize(q_img, temp_img, outputSize);
  156. addBorder(temp_img, Scalar(0,0,0), q_img);
  157. putText(q_img, "Query", Point(10, 30), FONT_HERSHEY_COMPLEX, 1.0, Scalar(0,255,0), 2);
  158. std::vector<Mat> Images;
  159. Images.push_back(q_img);
  160. for(int i = 0; i < (int)topk_index.size(); i++)
  161. {
  162. Mat g_img = imread(galleryImageNames[topk_index[i]]);
  163. resize(g_img, temp_img, outputSize);
  164. addBorder(temp_img, Scalar(255,255,255), g_img);
  165. putText(g_img, "G" + std::to_string(i), Point(10, 30), FONT_HERSHEY_COMPLEX, 1.0, Scalar(0,255,0), 2);
  166. Images.push_back(g_img);
  167. }
  168. hconcat(Images, result);
  169. return ;
  170. }
  171. static void visualization(const std::vector<std::vector<int>>& topk, const std::vector<std::string>& queryImageNames, const std::vector<std::string>& galleryImageNames, const std::string& output_dir, const int& resize_h, const int& resize_w)
  172. {
  173. for(int i = 0; i < (int)queryImageNames.size(); i++)
  174. {
  175. Mat img;
  176. drawRankList(queryImageNames[i], galleryImageNames, topk[i], resize_h, resize_w, img);
  177. std::string output_path = output_dir + "/" + queryImageNames[i].substr(queryImageNames[i].rfind("/")+1);
  178. imwrite(output_path, img);
  179. }
  180. return ;
  181. }
  182. };
  183. };
  184. int main(int argc, char** argv)
  185. {
  186. // Parse command line arguments.
  187. CommandLineParser parser(argc, argv, keys);
  188. if (argc == 1 || parser.has("help"))
  189. {
  190. parser.printMessage();
  191. return 0;
  192. }
  193. parser = CommandLineParser(argc, argv, keys);
  194. parser.about("Use this script to run ReID networks using OpenCV.");
  195. const std::string modelPath = parser.get<String>("model");
  196. const std::string queryImageList = parser.get<String>("query_list");
  197. const std::string galleryImageList = parser.get<String>("gallery_list");
  198. const int backend = parser.get<int>("backend");
  199. const int target = parser.get<int>("target");
  200. const int batch_size = parser.get<int>("batch_size");
  201. const int resize_h = parser.get<int>("resize_h");
  202. const int resize_w = parser.get<int>("resize_w");
  203. const int topk = parser.get<int>("topk");
  204. const std::string output_dir= parser.get<String>("output_dir");
  205. std::vector<std::string> queryImageNames;
  206. reid::getNames(queryImageList, queryImageNames);
  207. std::vector<std::string> galleryImageNames;
  208. reid::getNames(galleryImageList, galleryImageNames);
  209. dnn::Net net = dnn::readNet(modelPath);
  210. net.setPreferableBackend(backend);
  211. net.setPreferableTarget(target);
  212. std::vector<std::vector<float>> queryFeatures;
  213. reid::extractFeatures(queryImageNames, &net, batch_size, resize_h, resize_w, queryFeatures);
  214. std::vector<std::vector<float>> galleryFeatures;
  215. reid::extractFeatures(galleryImageNames, &net, batch_size, resize_h, resize_w, galleryFeatures);
  216. std::vector<std::vector<int>> topkResult;
  217. reid::getTopK(queryFeatures, galleryFeatures, topk, topkResult);
  218. reid::visualization(topkResult, queryImageNames, galleryImageNames, output_dir, resize_h, resize_w);
  219. return 0;
  220. }