imagenet_cls_test_inception.py 3.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. import numpy as np
  2. import sys
  3. import os
  4. import argparse
  5. import tensorflow as tf
  6. from tensorflow.python.platform import gfile
  7. from imagenet_cls_test_alexnet import MeanValueFetch, DnnCaffeModel, Framework, ClsAccEvaluation
  8. try:
  9. import cv2 as cv
  10. except ImportError:
  11. raise ImportError('Can\'t find OpenCV Python module. If you\'ve built it from sources without installation, '
  12. 'configure environment variable PYTHONPATH to "opencv_build_dir/lib" directory (with "python3" subdirectory if required)')
# If you get an exception like "Cannot load libmkl_avx.so or libmkl_def.so", try exporting the following variable
# before running the script:
# LD_PRELOAD=/opt/intel/mkl/lib/intel64/libmkl_core.so:/opt/intel/mkl/lib/intel64/libmkl_sequential.so
  16. class TensorflowModel(Framework):
  17. sess = tf.Session
  18. output = tf.Graph
  19. def __init__(self, model_file, in_blob_name, out_blob_name):
  20. self.in_blob_name = in_blob_name
  21. self.sess = tf.Session()
  22. with gfile.FastGFile(model_file, 'rb') as f:
  23. graph_def = tf.GraphDef()
  24. graph_def.ParseFromString(f.read())
  25. self.sess.graph.as_default()
  26. tf.import_graph_def(graph_def, name='')
  27. self.output = self.sess.graph.get_tensor_by_name(out_blob_name + ":0")
  28. def get_name(self):
  29. return 'Tensorflow'
  30. def get_output(self, input_blob):
  31. assert len(input_blob.shape) == 4
  32. batch_tf = input_blob.transpose(0, 2, 3, 1)
  33. out = self.sess.run(self.output,
  34. {self.in_blob_name+':0': batch_tf})
  35. out = out[..., 1:1001]
  36. return out
  37. class DnnTfInceptionModel(DnnCaffeModel):
  38. net = cv.dnn.Net()
  39. def __init__(self, model_file, in_blob_name, out_blob_name):
  40. self.net = cv.dnn.readNetFromTensorflow(model_file)
  41. self.in_blob_name = in_blob_name
  42. self.out_blob_name = out_blob_name
  43. def get_output(self, input_blob):
  44. return super(DnnTfInceptionModel, self).get_output(input_blob)[..., 1:1001]
  45. if __name__ == "__main__":
  46. parser = argparse.ArgumentParser()
  47. parser.add_argument("--imgs_dir", help="path to ImageNet validation subset images dir, ILSVRC2012_img_val dir")
  48. parser.add_argument("--img_cls_file", help="path to file with classes ids for images, download it here:"
  49. "https://github.com/opencv/opencv_extra/tree/4.x/testdata/dnn/img_classes_inception.txt")
  50. parser.add_argument("--model", help="path to tensorflow model, download it here:"
  51. "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip")
  52. parser.add_argument("--log", help="path to logging file")
  53. parser.add_argument("--batch_size", help="size of images in batch", default=1)
  54. parser.add_argument("--frame_size", help="size of input image", default=224)
  55. parser.add_argument("--in_blob", help="name for input blob", default='input')
  56. parser.add_argument("--out_blob", help="name for output blob", default='softmax2')
  57. args = parser.parse_args()
  58. data_fetcher = MeanValueFetch(args.frame_size, args.imgs_dir, True)
  59. frameworks = [TensorflowModel(args.model, args.in_blob, args.out_blob),
  60. DnnTfInceptionModel(args.model, '', args.out_blob)]
  61. acc_eval = ClsAccEvaluation(args.log, args.img_cls_file, args.batch_size)
  62. acc_eval.process(frameworks, data_fetcher)