person_reid.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. #!/usr/bin/env python
  2. '''
  3. You can download a baseline ReID model and sample input from:
  4. https://github.com/ReID-Team/ReID_extra_testdata
  5. Authors of samples and Youtu ReID baseline:
  6. Xing Sun <winfredsun@tencent.com>
  7. Feng Zheng <zhengf@sustech.edu.cn>
  8. Xinyang Jiang <sevjiang@tencent.com>
  9. Fufu Yu <fufuyu@tencent.com>
  10. Enwei Zhang <miyozhang@tencent.com>
  11. Copyright (C) 2020-2021, Tencent.
  12. Copyright (C) 2020-2021, SUSTech.
  13. '''
  14. import argparse
  15. import os.path
  16. import numpy as np
  17. import cv2 as cv
  18. backends = (cv.dnn.DNN_BACKEND_DEFAULT,
  19. cv.dnn.DNN_BACKEND_INFERENCE_ENGINE,
  20. cv.dnn.DNN_BACKEND_OPENCV,
  21. cv.dnn.DNN_BACKEND_VKCOM,
  22. cv.dnn.DNN_BACKEND_CUDA)
  23. targets = (cv.dnn.DNN_TARGET_CPU,
  24. cv.dnn.DNN_TARGET_OPENCL,
  25. cv.dnn.DNN_TARGET_OPENCL_FP16,
  26. cv.dnn.DNN_TARGET_MYRIAD,
  27. cv.dnn.DNN_TARGET_HDDL,
  28. cv.dnn.DNN_TARGET_VULKAN,
  29. cv.dnn.DNN_TARGET_CUDA,
  30. cv.dnn.DNN_TARGET_CUDA_FP16)
  31. MEAN = (0.485, 0.456, 0.406)
  32. STD = (0.229, 0.224, 0.225)
  33. def preprocess(images, height, width):
  34. """
  35. Create 4-dimensional blob from image
  36. :param image: input image
  37. :param height: the height of the resized input image
  38. :param width: the width of the resized input image
  39. """
  40. img_list = []
  41. for image in images:
  42. image = cv.resize(image, (width, height))
  43. img_list.append(image[:, :, ::-1])
  44. images = np.array(img_list)
  45. images = (images / 255.0 - MEAN) / STD
  46. input = cv.dnn.blobFromImages(images.astype(np.float32), ddepth = cv.CV_32F)
  47. return input
  48. def extract_feature(img_dir, model_path, batch_size = 32, resize_h = 384, resize_w = 128, backend=cv.dnn.DNN_BACKEND_OPENCV, target=cv.dnn.DNN_TARGET_CPU):
  49. """
  50. Extract features from images in a target directory
  51. :param img_dir: the input image directory
  52. :param model_path: path to ReID model
  53. :param batch_size: the batch size for each network inference iteration
  54. :param resize_h: the height of the input image
  55. :param resize_w: the width of the input image
  56. :param backend: name of computation backend
  57. :param target: name of computation target
  58. """
  59. feat_list = []
  60. path_list = os.listdir(img_dir)
  61. path_list = [os.path.join(img_dir, img_name) for img_name in path_list]
  62. count = 0
  63. for i in range(0, len(path_list), batch_size):
  64. print('Feature Extraction for images in', img_dir, 'Batch:', count, '/', len(path_list))
  65. batch = path_list[i : min(i + batch_size, len(path_list))]
  66. imgs = read_data(batch)
  67. inputs = preprocess(imgs, resize_h, resize_w)
  68. feat = run_net(inputs, model_path, backend, target)
  69. feat_list.append(feat)
  70. count += batch_size
  71. feats = np.concatenate(feat_list, axis = 0)
  72. return feats, path_list
  73. def run_net(inputs, model_path, backend=cv.dnn.DNN_BACKEND_OPENCV, target=cv.dnn.DNN_TARGET_CPU):
  74. """
  75. Forword propagation for a batch of images.
  76. :param inputs: input batch of images
  77. :param model_path: path to ReID model
  78. :param backend: name of computation backend
  79. :param target: name of computation target
  80. """
  81. net = cv.dnn.readNet(model_path)
  82. net.setPreferableBackend(backend)
  83. net.setPreferableTarget(target)
  84. net.setInput(inputs)
  85. out = net.forward()
  86. out = np.reshape(out, (out.shape[0], out.shape[1]))
  87. return out
  88. def read_data(path_list):
  89. """
  90. Read all images from a directory into a list
  91. :param path_list: the list of image path
  92. """
  93. img_list = []
  94. for img_path in path_list:
  95. img = cv.imread(img_path)
  96. if img is None:
  97. continue
  98. img_list.append(img)
  99. return img_list
  100. def normalize(nparray, order=2, axis=0):
  101. """
  102. Normalize a N-D numpy array along the specified axis.
  103. :param nparry: the array of vectors to be normalized
  104. :param order: order of the norm
  105. :param axis: the axis of x along which to compute the vector norms
  106. """
  107. norm = np.linalg.norm(nparray, ord=order, axis=axis, keepdims=True)
  108. return nparray / (norm + np.finfo(np.float32).eps)
  109. def similarity(array1, array2):
  110. """
  111. Compute the euclidean or cosine distance of all pairs.
  112. :param array1: numpy array with shape [m1, n]
  113. :param array2: numpy array with shape [m2, n]
  114. Returns:
  115. numpy array with shape [m1, m2]
  116. """
  117. array1 = normalize(array1, axis=1)
  118. array2 = normalize(array2, axis=1)
  119. dist = np.matmul(array1, array2.T)
  120. return dist
  121. def topk(query_feat, gallery_feat, topk = 5):
  122. """
  123. Return the index of top K gallery images most similar to the query images
  124. :param query_feat: array of feature vectors of query images
  125. :param gallery_feat: array of feature vectors of gallery images
  126. :param topk: number of gallery images to return
  127. """
  128. sim = similarity(query_feat, gallery_feat)
  129. index = np.argsort(-sim, axis = 1)
  130. return [i[0:int(topk)] for i in index]
  131. def drawRankList(query_name, gallery_list, output_size = (128, 384)):
  132. """
  133. Draw the rank list
  134. :param query_name: path of the query image
  135. :param gallery_name: path of the gallery image
  136. "param output_size: the output size of each image in the rank list
  137. """
  138. def addBorder(im, color):
  139. bordersize = 5
  140. border = cv.copyMakeBorder(
  141. im,
  142. top = bordersize,
  143. bottom = bordersize,
  144. left = bordersize,
  145. right = bordersize,
  146. borderType = cv.BORDER_CONSTANT,
  147. value = color
  148. )
  149. return border
  150. query_img = cv.imread(query_name)
  151. query_img = cv.resize(query_img, output_size)
  152. query_img = addBorder(query_img, [0, 0, 0])
  153. cv.putText(query_img, 'Query', (10, 30), cv.FONT_HERSHEY_COMPLEX, 1., (0,255,0), 2)
  154. gallery_img_list = []
  155. for i, gallery_name in enumerate(gallery_list):
  156. gallery_img = cv.imread(gallery_name)
  157. gallery_img = cv.resize(gallery_img, output_size)
  158. gallery_img = addBorder(gallery_img, [255, 255, 255])
  159. cv.putText(gallery_img, 'G%02d'%i, (10, 30), cv.FONT_HERSHEY_COMPLEX, 1., (0,255,0), 2)
  160. gallery_img_list.append(gallery_img)
  161. ret = np.concatenate([query_img] + gallery_img_list, axis = 1)
  162. return ret
  163. def visualization(topk_idx, query_names, gallery_names, output_dir = 'vis'):
  164. """
  165. Visualize the retrieval results with the person ReID model
  166. :param topk_idx: the index of ranked gallery images for each query image
  167. :param query_names: the list of paths of query images
  168. :param gallery_names: the list of paths of gallery images
  169. :param output_dir: the path to save the visualize results
  170. """
  171. if not os.path.exists(output_dir):
  172. os.mkdir(output_dir)
  173. for i, idx in enumerate(topk_idx):
  174. query_name = query_names[i]
  175. topk_names = [gallery_names[j] for j in idx]
  176. vis_img = drawRankList(query_name, topk_names)
  177. output_path = os.path.join(output_dir, '%03d_%s'%(i, os.path.basename(query_name)))
  178. cv.imwrite(output_path, vis_img)
  179. if __name__ == '__main__':
  180. parser = argparse.ArgumentParser(description='Use this script to run human parsing using JPPNet',
  181. formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  182. parser.add_argument('--query_dir', '-q', required=True, help='Path to query image.')
  183. parser.add_argument('--gallery_dir', '-g', required=True, help='Path to gallery directory.')
  184. parser.add_argument('--resize_h', default = 256, help='The height of the input for model inference.')
  185. parser.add_argument('--resize_w', default = 128, help='The width of the input for model inference')
  186. parser.add_argument('--model', '-m', default='reid.onnx', help='Path to pb model.')
  187. parser.add_argument('--visualization_dir', default='vis', help='Path for the visualization results')
  188. parser.add_argument('--topk', default=10, help='Number of images visualized in the rank list')
  189. parser.add_argument('--batchsize', default=32, help='The batch size of each inference')
  190. parser.add_argument('--backend', choices=backends, default=cv.dnn.DNN_BACKEND_DEFAULT, type=int,
  191. help="Choose one of computation backends: "
  192. "%d: automatically (by default), "
  193. "%d: Intel's Deep Learning Inference Engine (https://software.intel.com/openvino-toolkit), "
  194. "%d: OpenCV implementation, "
  195. "%d: VKCOM, "
  196. "%d: CUDA backend"% backends)
  197. parser.add_argument('--target', choices=targets, default=cv.dnn.DNN_TARGET_CPU, type=int,
  198. help='Choose one of target computation devices: '
  199. '%d: CPU target (by default), '
  200. '%d: OpenCL, '
  201. '%d: OpenCL fp16 (half-float precision), '
  202. '%d: NCS2 VPU, '
  203. '%d: HDDL VPU, '
  204. '%d: Vulkan, '
  205. '%d: CUDA, '
  206. '%d: CUDA FP16'
  207. % targets)
  208. args, _ = parser.parse_known_args()
  209. if not os.path.isfile(args.model):
  210. raise OSError("Model not exist")
  211. query_feat, query_names = extract_feature(args.query_dir, args.model, args.batchsize, args.resize_h, args.resize_w, args.backend, args.target)
  212. gallery_feat, gallery_names = extract_feature(args.gallery_dir, args.model, args.batchsize, args.resize_h, args.resize_w, args.backend, args.target)
  213. topk_idx = topk(query_feat, gallery_feat, args.topk)
  214. visualization(topk_idx, query_names, gallery_names, output_dir = args.visualization_dir)