mobilenet_ssd_accuracy.py

from __future__ import print_function
# Script to evaluate MobileNet-SSD object detection model trained in TensorFlow
# using both TensorFlow and OpenCV. Example:
#
# python mobilenet_ssd_accuracy.py \
#   --weights=frozen_inference_graph.pb \
#   --prototxt=ssd_mobilenet_v1_coco.pbtxt \
#   --images=val2017 \
#   --annotations=annotations/instances_val2017.json
#
# Tested on COCO 2017 object detection dataset, http://cocodataset.org/#download
import os
import json
import argparse

import cv2 as cv

parser = argparse.ArgumentParser(
    description='Evaluate MobileNet-SSD model using both TensorFlow and OpenCV. '
                'COCO evaluation framework is required: http://cocodataset.org')
parser.add_argument('--weights', required=True,
                    help='Path to frozen_inference_graph.pb of MobileNet-SSD model. '
                         'Download it from http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_11_06_2017.tar.gz')
parser.add_argument('--prototxt', required=True,
                    help='Path to ssd_mobilenet_v1_coco.pbtxt from opencv_extra.')
parser.add_argument('--images', required=True,
                    help='Path to COCO validation images directory.')
parser.add_argument('--annotations', required=True,
                    help='Path to COCO annotations file.')
args = parser.parse_args()

### Get OpenCV predictions #####################################################
net = cv.dnn.readNetFromTensorflow(cv.samples.findFile(args.weights),
                                   cv.samples.findFile(args.prototxt))
net.setPreferableBackend(cv.dnn.DNN_BACKEND_OPENCV)
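# DNN_BACKEND_OPENCV selects OpenCV's own built-in implementation (CPU target
# by default), so the numbers below do not depend on which optional inference
# backends the OpenCV build happens to include.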

detections = []
for imgName in os.listdir(args.images):
    inp = cv.imread(cv.samples.findFile(os.path.join(args.images, imgName)))
    rows = inp.shape[0]
    cols = inp.shape[1]
    inp = cv.resize(inp, (300, 300))
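
    # blobFromImage below subtracts a mean of 127.5 and scales by 1/127.5,
    # mapping pixel values from [0, 255] to [-1, 1]; swapRB=True converts
    # OpenCV's BGR channel order to the RGB order the network was trained with.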
    net.setInput(cv.dnn.blobFromImage(inp, 1.0 / 127.5, (300, 300), (127.5, 127.5, 127.5), True))
    out = net.forward()
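
    # The detection output blob has shape [1, 1, N, 7]; every row packs
    # [batchId, classId, score, left, top, right, bottom] with the box
    # corners normalized to [0, 1].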
    for i in range(out.shape[2]):
        score = float(out[0, 0, i, 2])
        # The confidence threshold is applied inside the graph (see the .pbtxt).
        classId = int(out[0, 0, i, 1])

        x = out[0, 0, i, 3] * cols
        y = out[0, 0, i, 4] * rows
        w = out[0, 0, i, 5] * cols - x
        h = out[0, 0, i, 6] * rows - y

        detections.append({
            # COCO image ids are the numeric part of the file name,
            # e.g. 000000397133.jpg -> 397133.
            "image_id": int(imgName[:imgName.rfind('.')]),
            "category_id": classId,
            "bbox": [x, y, w, h],
            "score": score
        })

with open('cv_result.json', 'wt') as f:
    json.dump(detections, f)
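
# A minimal sanity-check sketch (not part of the evaluation): draw one stored
# detection back onto its image. The 12-digit file name pattern matches COCO
# val2017; 'check.png' is just an illustrative output name.
#
#   d = detections[0]
#   img = cv.imread(os.path.join(args.images, '%012d.jpg' % d['image_id']))
#   x, y, w, h = [int(v) for v in d['bbox']]
#   cv.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 2)
#   cv.imwrite('check.png', img)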

### Get TensorFlow predictions #################################################
import tensorflow as tf

# The frozen graph is a binary protobuf, so it must be opened in 'rb' mode.
with tf.gfile.FastGFile(args.weights, 'rb') as f:
    # Load the model
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Session() as sess:
    # Restore session
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')

    detections = []
    for imgName in os.listdir(args.images):
        inp = cv.imread(os.path.join(args.images, imgName))
        rows = inp.shape[0]
        cols = inp.shape[1]
        inp = cv.resize(inp, (300, 300))
        inp = inp[:, :, [2, 1, 0]]  # BGR2RGB

        out = sess.run([sess.graph.get_tensor_by_name('num_detections:0'),
                        sess.graph.get_tensor_by_name('detection_scores:0'),
                        sess.graph.get_tensor_by_name('detection_boxes:0'),
                        sess.graph.get_tensor_by_name('detection_classes:0')],
                       feed_dict={'image_tensor:0': inp.reshape(1, inp.shape[0], inp.shape[1], 3)})
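
        # out[0] is the number of valid detections, out[1] the scores,
        # out[2] the boxes as normalized [ymin, xmin, ymax, xmax], and
        # out[3] the class ids, each with a leading batch dimension.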
        num_detections = int(out[0][0])
        for i in range(num_detections):
            classId = int(out[3][0][i])
            score = float(out[1][0][i])
            bbox = [float(v) for v in out[2][0][i]]
            # OpenCV's graph thresholds confidences internally;
            # here the same cut-off has to be applied manually.
            if score > 0.01:
                x = bbox[1] * cols
                y = bbox[0] * rows
                w = bbox[3] * cols - x
                h = bbox[2] * rows - y
                detections.append({
                    "image_id": int(imgName[:imgName.rfind('.')]),
                    "category_id": classId,
                    "bbox": [x, y, w, h],
                    "score": score
                })

with open('tf_result.json', 'wt') as f:
    json.dump(detections, f)
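
# Note: the block above uses the TensorFlow 1.x API (tf.gfile, tf.GraphDef,
# tf.Session). A sketch of how the same code can run on TensorFlow 2.x via
# the compatibility layer, assuming it is available in your installation:
#
#   import tensorflow.compat.v1 as tf
#   tf.disable_v2_behavior()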

### Evaluation part ############################################################
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

annType = 'bbox'  # one of 'segm', 'bbox', 'keypoints'
print('Running demo for *%s* results.' % annType)

# Initialize the COCO ground truth API.
cocoGt = COCO(args.annotations)

# Evaluate both result files against the same ground truth.
for resFile in ['tf_result.json', 'cv_result.json']:
    print(resFile)
    cocoDt = cocoGt.loadRes(resFile)
    cocoEval = COCOeval(cocoGt, cocoDt, annType)
    cocoEval.evaluate()
    cocoEval.accumulate()
    cocoEval.summarize()
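
# cocoEval.summarize() prints the 12 standard COCO detection metrics: AP
# averaged over IoU thresholds 0.50:0.95, AP at IoU 0.50 and 0.75, AP for
# small/medium/large objects, and the matching AR numbers. Close AP lines for
# 'tf_result.json' and 'cv_result.json' indicate that the OpenCV import of the
# model reproduces the TensorFlow reference.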