# imagenet_cls_test_alexnet.py (9.5 KB)

from __future__ import print_function
from abc import ABCMeta, abstractmethod
import numpy as np
import sys
import os
import argparse
import time
# The two framework bindings under comparison. A failed import is re-raised
# with a hint about the PYTHONPATH needed for from-source (uninstalled) builds.
try:
    import caffe
except ImportError:
    raise ImportError('Can\'t find Caffe Python module. If you\'ve built it from sources without installation, '
                      'configure environment variable PYTHONPATH to "git/caffe/python" directory')
try:
    import cv2 as cv
except ImportError:
    raise ImportError('Can\'t find OpenCV Python module. If you\'ve built it from sources without installation, '
                      'configure environment variable PYTHONPATH to "opencv_build_dir/lib" directory (with "python3" subdirectory if required)')
  18. try:
  19. xrange # Python 2
  20. except NameError:
  21. xrange = range # Python 3
  22. class DataFetch(object):
  23. imgs_dir = ''
  24. frame_size = 0
  25. bgr_to_rgb = False
  26. __metaclass__ = ABCMeta
  27. @abstractmethod
  28. def preprocess(self, img):
  29. pass
  30. def get_batch(self, imgs_names):
  31. assert type(imgs_names) is list
  32. batch = np.zeros((len(imgs_names), 3, self.frame_size, self.frame_size)).astype(np.float32)
  33. for i in range(len(imgs_names)):
  34. img_name = imgs_names[i]
  35. img_file = self.imgs_dir + img_name
  36. assert os.path.exists(img_file)
  37. img = cv.imread(img_file, cv.IMREAD_COLOR)
  38. min_dim = min(img.shape[-3], img.shape[-2])
  39. resize_ratio = self.frame_size / float(min_dim)
  40. img = cv.resize(img, (0, 0), fx=resize_ratio, fy=resize_ratio)
  41. cols = img.shape[1]
  42. rows = img.shape[0]
  43. y1 = (rows - self.frame_size) / 2
  44. y2 = y1 + self.frame_size
  45. x1 = (cols - self.frame_size) / 2
  46. x2 = x1 + self.frame_size
  47. img = img[y1:y2, x1:x2]
  48. if self.bgr_to_rgb:
  49. img = img[..., ::-1]
  50. image_data = img[:, :, 0:3].transpose(2, 0, 1)
  51. batch[i] = self.preprocess(image_data)
  52. return batch
  53. class MeanBlobFetch(DataFetch):
  54. mean_blob = np.ndarray(())
  55. def __init__(self, frame_size, mean_blob_path, imgs_dir):
  56. self.imgs_dir = imgs_dir
  57. self.frame_size = frame_size
  58. blob = caffe.proto.caffe_pb2.BlobProto()
  59. data = open(mean_blob_path, 'rb').read()
  60. blob.ParseFromString(data)
  61. self.mean_blob = np.array(caffe.io.blobproto_to_array(blob))
  62. start = (self.mean_blob.shape[2] - self.frame_size) / 2
  63. stop = start + self.frame_size
  64. self.mean_blob = self.mean_blob[:, :, start:stop, start:stop][0]
  65. def preprocess(self, img):
  66. return img - self.mean_blob
  67. class MeanChannelsFetch(MeanBlobFetch):
  68. def __init__(self, frame_size, imgs_dir):
  69. self.imgs_dir = imgs_dir
  70. self.frame_size = frame_size
  71. self.mean_blob = np.ones((3, self.frame_size, self.frame_size)).astype(np.float32)
  72. self.mean_blob[0] *= 104
  73. self.mean_blob[1] *= 117
  74. self.mean_blob[2] *= 123
  75. class MeanValueFetch(MeanBlobFetch):
  76. def __init__(self, frame_size, imgs_dir, bgr_to_rgb):
  77. self.imgs_dir = imgs_dir
  78. self.frame_size = frame_size
  79. self.mean_blob = np.ones((3, self.frame_size, self.frame_size)).astype(np.float32)
  80. self.mean_blob *= 117
  81. self.bgr_to_rgb = bgr_to_rgb
  82. def get_correct_answers(img_list, img_classes, net_output_blob):
  83. correct_answers = 0
  84. for i in range(len(img_list)):
  85. indexes = np.argsort(net_output_blob[i])[-5:]
  86. correct_index = img_classes[img_list[i]]
  87. if correct_index in indexes:
  88. correct_answers += 1
  89. return correct_answers
  90. class Framework(object):
  91. in_blob_name = ''
  92. out_blob_name = ''
  93. __metaclass__ = ABCMeta
  94. @abstractmethod
  95. def get_name(self):
  96. pass
  97. @abstractmethod
  98. def get_output(self, input_blob):
  99. pass
  100. class CaffeModel(Framework):
  101. net = caffe.Net
  102. need_reshape = False
  103. def __init__(self, prototxt, caffemodel, in_blob_name, out_blob_name, need_reshape=False):
  104. caffe.set_mode_cpu()
  105. self.net = caffe.Net(prototxt, caffemodel, caffe.TEST)
  106. self.in_blob_name = in_blob_name
  107. self.out_blob_name = out_blob_name
  108. self.need_reshape = need_reshape
  109. def get_name(self):
  110. return 'Caffe'
  111. def get_output(self, input_blob):
  112. if self.need_reshape:
  113. self.net.blobs[self.in_blob_name].reshape(*input_blob.shape)
  114. return self.net.forward_all(**{self.in_blob_name: input_blob})[self.out_blob_name]
  115. class DnnCaffeModel(Framework):
  116. net = object
  117. def __init__(self, prototxt, caffemodel, in_blob_name, out_blob_name):
  118. self.net = cv.dnn.readNetFromCaffe(prototxt, caffemodel)
  119. self.in_blob_name = in_blob_name
  120. self.out_blob_name = out_blob_name
  121. def get_name(self):
  122. return 'DNN'
  123. def get_output(self, input_blob):
  124. self.net.setInput(input_blob, self.in_blob_name)
  125. return self.net.forward(self.out_blob_name)
  126. class ClsAccEvaluation:
  127. log = sys.stdout
  128. img_classes = {}
  129. batch_size = 0
  130. def __init__(self, log_path, img_classes_file, batch_size):
  131. self.log = open(log_path, 'w')
  132. self.img_classes = self.read_classes(img_classes_file)
  133. self.batch_size = batch_size
  134. @staticmethod
  135. def read_classes(img_classes_file):
  136. result = {}
  137. with open(img_classes_file) as file:
  138. for l in file.readlines():
  139. result[l.split()[0]] = int(l.split()[1])
  140. return result
  141. def process(self, frameworks, data_fetcher):
  142. sorted_imgs_names = sorted(self.img_classes.keys())
  143. correct_answers = [0] * len(frameworks)
  144. samples_handled = 0
  145. blobs_l1_diff = [0] * len(frameworks)
  146. blobs_l1_diff_count = [0] * len(frameworks)
  147. blobs_l_inf_diff = [sys.float_info.min] * len(frameworks)
  148. inference_time = [0.0] * len(frameworks)
  149. for x in xrange(0, len(sorted_imgs_names), self.batch_size):
  150. sublist = sorted_imgs_names[x:x + self.batch_size]
  151. batch = data_fetcher.get_batch(sublist)
  152. samples_handled += len(sublist)
  153. frameworks_out = []
  154. fw_accuracy = []
  155. for i in range(len(frameworks)):
  156. start = time.time()
  157. out = frameworks[i].get_output(batch)
  158. end = time.time()
  159. correct_answers[i] += get_correct_answers(sublist, self.img_classes, out)
  160. fw_accuracy.append(100 * correct_answers[i] / float(samples_handled))
  161. frameworks_out.append(out)
  162. inference_time[i] += end - start
  163. print(samples_handled, 'Accuracy for', frameworks[i].get_name() + ':', fw_accuracy[i], file=self.log)
  164. print("Inference time, ms ", \
  165. frameworks[i].get_name(), inference_time[i] / samples_handled * 1000, file=self.log)
  166. for i in range(1, len(frameworks)):
  167. log_str = frameworks[0].get_name() + " vs " + frameworks[i].get_name() + ':'
  168. diff = np.abs(frameworks_out[0] - frameworks_out[i])
  169. l1_diff = np.sum(diff) / diff.size
  170. print(samples_handled, "L1 difference", log_str, l1_diff, file=self.log)
  171. blobs_l1_diff[i] += l1_diff
  172. blobs_l1_diff_count[i] += 1
  173. if np.max(diff) > blobs_l_inf_diff[i]:
  174. blobs_l_inf_diff[i] = np.max(diff)
  175. print(samples_handled, "L_INF difference", log_str, blobs_l_inf_diff[i], file=self.log)
  176. self.log.flush()
  177. for i in range(1, len(blobs_l1_diff)):
  178. log_str = frameworks[0].get_name() + " vs " + frameworks[i].get_name() + ':'
  179. print('Final l1 diff', log_str, blobs_l1_diff[i] / blobs_l1_diff_count[i], file=self.log)
  180. if __name__ == "__main__":
  181. parser = argparse.ArgumentParser()
  182. parser.add_argument("--imgs_dir", help="path to ImageNet validation subset images dir, ILSVRC2012_img_val dir")
  183. parser.add_argument("--img_cls_file", help="path to file with classes ids for images, val.txt file from this "
  184. "archive: http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz")
  185. parser.add_argument("--prototxt", help="path to caffe prototxt, download it here: "
  186. "https://github.com/BVLC/caffe/blob/master/models/bvlc_alexnet/deploy.prototxt")
  187. parser.add_argument("--caffemodel", help="path to caffemodel file, download it here: "
  188. "http://dl.caffe.berkeleyvision.org/bvlc_alexnet.caffemodel")
  189. parser.add_argument("--log", help="path to logging file")
  190. parser.add_argument("--mean", help="path to ImageNet mean blob caffe file, imagenet_mean.binaryproto file from"
  191. "this archive: http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz")
  192. parser.add_argument("--batch_size", help="size of images in batch", default=1000)
  193. parser.add_argument("--frame_size", help="size of input image", default=227)
  194. parser.add_argument("--in_blob", help="name for input blob", default='data')
  195. parser.add_argument("--out_blob", help="name for output blob", default='prob')
  196. args = parser.parse_args()
  197. data_fetcher = MeanBlobFetch(args.frame_size, args.mean, args.imgs_dir)
  198. frameworks = [CaffeModel(args.prototxt, args.caffemodel, args.in_blob, args.out_blob),
  199. DnnCaffeModel(args.prototxt, args.caffemodel, '', args.out_blob)]
  200. acc_eval = ClsAccEvaluation(args.log, args.img_cls_file, args.batch_size)
  201. acc_eval.process(frameworks, data_fetcher)