optical_flow_benchmark.py

#!/usr/bin/env python
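"""Optical flow benchmarking script.

Runs an external optical flow evaluation executable over consecutive frame pairs
of the MPI Sintel and/or Middlebury training sequences, stores per-frame timing
and endpoint-error statistics in a resumable JSON state file, and plots average
endpoint error against average per-frame runtime for every recorded run.
"""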
from __future__ import print_function
import os, sys, shutil
import argparse
import json, re
from subprocess import check_output
import datetime
import matplotlib.pyplot as plt

def load_json(path):
    with open(path, "r") as f:
        return json.load(f)

def save_json(obj, path):
    # Write to a temporary file and rename it over the destination so that an
    # interrupted run cannot leave a truncated state file behind.
    tmp_file = path + ".bak"
    with open(tmp_file, "w") as f:
        json.dump(obj, f, indent=2)
        f.flush()
        os.fsync(f.fileno())
    try:
        os.rename(tmp_file, path)
    except OSError:
        # On Windows os.rename fails if the destination exists; remove it first.
        os.remove(path)
        os.rename(tmp_file, path)

def parse_evaluation_result(input_str, i):
    res = {}
    res['frame_number'] = i + 1
    res['error'] = {}
    # Match "<metric name>: <number>" pairs printed by the evaluation executable.
    regex = r"([A-Za-z. \[\].0-9]+):[ ]*([0-9]*\.[0-9]+|[0-9]+)"
    for elem in re.findall(regex, input_str):
        if "Time" in elem[0]:
            res['time'] = float(elem[1])
        elif "Average" in elem[0]:
            res['error']['average'] = float(elem[1])
        elif "deviation" in elem[0]:
            res['error']['std'] = float(elem[1])
        else:
            res['error'][elem[0]] = float(elem[1])
    return res
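# Illustrative example (the exact labels depend on the evaluation executable's
# output): a line such as "Time [s]: 0.15" is stored as res['time'] = 0.15, and
# "Average endpoint error: 0.42" as res['error']['average'] = 0.42.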

def evaluate_sequence(sequence, algorithm, dataset, executable, img_files, gt_files,
                      state, state_path):
    if "eval_results" not in state[dataset][algorithm][-1].keys():
        state[dataset][algorithm][-1]["eval_results"] = {}
    elif sequence in state[dataset][algorithm][-1]["eval_results"].keys():
        # This sequence was already evaluated in a previous (possibly interrupted) run.
        return
    res = []
    for i in range(len(img_files) - 1):
        sys.stdout.write("Algorithm: %-20s Sequence: %-10s Done: [%3d/%3d]\r" %
                         (algorithm, sequence, i, len(img_files) - 1))
        sys.stdout.flush()
        # check_output returns bytes on Python 3; decode so the regex parser
        # receives a text string on both Python 2 and 3.
        res_string = check_output([executable, img_files[i], img_files[i + 1],
                                   algorithm, gt_files[i]]).decode("utf-8")
        res.append(parse_evaluation_result(res_string, i))
    state[dataset][algorithm][-1]["eval_results"][sequence] = res
    save_json(state, state_path)
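# After a successful call, the state file contains
#   state[dataset][algorithm][-1]["eval_results"][sequence] == [frame_0, frame_1, ...]
# where each frame entry is the dict produced by parse_evaluation_result().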
############################# DATASET DEFINITIONS ##############################
def evaluate_mpi_sintel(source_dir, algorithm, evaluation_executable, state, state_path):
    evaluation_result = {}
    img_dir = os.path.join(source_dir, 'mpi_sintel', 'training', 'final')
    gt_dir = os.path.join(source_dir, 'mpi_sintel', 'training', 'flow')
    sequences = [f for f in os.listdir(img_dir)
                 if os.path.isdir(os.path.join(img_dir, f))]
    for seq in sequences:
        img_files = sorted([os.path.join(img_dir, seq, f)
                            for f in os.listdir(os.path.join(img_dir, seq))
                            if f.endswith(".png")])
        gt_files = sorted([os.path.join(gt_dir, seq, f)
                           for f in os.listdir(os.path.join(gt_dir, seq))
                           if f.endswith(".flo")])
        evaluation_result[seq] = evaluate_sequence(seq, algorithm, 'mpi_sintel',
                                                   evaluation_executable, img_files,
                                                   gt_files, state, state_path)
    return evaluation_result

def evaluate_middlebury(source_dir, algorithm, evaluation_executable, state, state_path):
    evaluation_result = {}
    img_dir = os.path.join(source_dir, 'middlebury', 'other-data')
    gt_dir = os.path.join(source_dir, 'middlebury', 'other-gt-flow')
    sequences = [f for f in os.listdir(gt_dir)
                 if os.path.isdir(os.path.join(gt_dir, f))]
    for seq in sequences:
        img_files = sorted([os.path.join(img_dir, seq, f)
                            for f in os.listdir(os.path.join(img_dir, seq))
                            if f.endswith(".png")])
        gt_files = sorted([os.path.join(gt_dir, seq, f)
                           for f in os.listdir(os.path.join(gt_dir, seq))
                           if f.endswith(".flo")])
        evaluation_result[seq] = evaluate_sequence(seq, algorithm, 'middlebury',
                                                   evaluation_executable, img_files,
                                                   gt_files, state, state_path)
    return evaluation_result

dataset_eval_functions = {
    "mpi_sintel": evaluate_mpi_sintel,
    "middlebury": evaluate_middlebury
}
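# To support an additional dataset, one could (this is a sketch, not part of the
# original script; 'my_dataset' is a hypothetical name) write an evaluate_<name>()
# function that builds the per-sequence image/ground-truth file lists in the same
# way and calls evaluate_sequence(), then register it here:
#   dataset_eval_functions["my_dataset"] = evaluate_my_dataset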
###############################################################################

def create_dir(dir):
    if not os.path.exists(dir):
        os.makedirs(dir)

def parse_sequence(input_str):
    if len(input_str) == 0:
        return []
    else:
        return [o.strip() for o in input_str.split(",") if o]
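# Example: parse_sequence("farneback, tvl1") -> ['farneback', 'tvl1'];
# an empty string yields [].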

def build_chart(dst_folder, state, dataset):
    fig = plt.figure(figsize=(16, 10))
    markers = ["o", "s", "h", "^", "D"]
    marker_idx = 0
    colors = ["b", "g", "r"]
    color_idx = 0
    for algo in state[dataset].keys():
        for eval_instance in state[dataset][algo]:
            name = algo + "--" + eval_instance["timestamp"]
            average_time = 0.0
            average_error = 0.0
            num_elem = 0
            for seq in eval_instance["eval_results"].keys():
                for frame in eval_instance["eval_results"][seq]:
                    average_time += frame["time"]
                    average_error += frame["error"]["average"]
                    num_elem += 1
            average_time /= num_elem
            average_error /= num_elem
            # Cycle through color/marker combinations so every run gets a
            # distinct point style.
            marker_style = colors[color_idx] + markers[marker_idx]
            color_idx += 1
            if color_idx >= len(colors):
                color_idx = 0
                marker_idx += 1
                if marker_idx >= len(markers):
                    marker_idx = 0
            plt.gca().plot([average_time], [average_error],
                           marker_style,
                           markersize=14,
                           label=name)
    plt.gca().set_ylabel('Average Endpoint Error (EPE)', fontsize=20)
    plt.gca().set_xlabel('Average Runtime (seconds per frame)', fontsize=20)
    plt.gca().set_xscale("log")
    plt.gca().set_title('Evaluation on ' + dataset, fontsize=20)
    plt.gca().legend()
    fig.savefig(os.path.join(dst_folder, "evaluation_results_" + dataset + ".png"),
                bbox_inches='tight')
    plt.close()
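# Each plotted point corresponds to one (algorithm, timestamp) run: x is the mean
# per-frame runtime (log scale), y is the mean endpoint error over all frames of
# all evaluated sequences.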

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Optical flow benchmarking script',
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument(
        "bin_path",
        default="./optflow-example-optical_flow_evaluation",
        help="Path to the optical flow evaluation executable")
    parser.add_argument(
        "-a",
        "--algorithms",
        metavar="ALGORITHMS",
        default="",
        help=("Comma-separated list of optical-flow algorithms to evaluate "
              "(example: -a farneback,tvl1,deepflow). Note that previously "
              "evaluated algorithms are also included in the output charts"))
    parser.add_argument(
        "-d",
        "--datasets",
        metavar="DATASETS",
        default="mpi_sintel",
        help=("Comma-separated list of datasets for evaluation (currently only "
              "'mpi_sintel' and 'middlebury' are supported)"))
    parser.add_argument(
        "-f",
        "--dataset_folder",
        metavar="DATASET_FOLDER",
        default="./OF_datasets",
        help=("Path to a folder containing the datasets. To enable evaluation on "
              "the MPI Sintel dataset, please download it using the following links: "
              "http://files.is.tue.mpg.de/sintel/MPI-Sintel-training_images.zip and "
              "http://files.is.tue.mpg.de/sintel/MPI-Sintel-training_extras.zip and "
              "unzip these archives into the 'mpi_sintel' folder. To enable evaluation "
              "on the Middlebury dataset use the following links: "
              "http://vision.middlebury.edu/flow/data/comp/zip/other-color-twoframes.zip, "
              "http://vision.middlebury.edu/flow/data/comp/zip/other-gt-flow.zip. "
              "These should be unzipped into the 'middlebury' folder"))
    parser.add_argument(
        "-o",
        "--out",
        metavar="OUT_DIR",
        default="./OF_evaluation_results",
        help="Output directory where to store benchmark results")
    parser.add_argument(
        "-s",
        "--state",
        metavar="STATE_JSON",
        default="./OF_evaluation_state.json",
        help=("Path to a json file that stores the current evaluation state and "
              "previous evaluation results"))
    args, other_args = parser.parse_known_args()
    if not os.path.isfile(args.bin_path):
        print("Error: " + args.bin_path + " does not exist")
        sys.exit(1)
    if not os.path.exists(args.dataset_folder):
        print("Error: " + args.dataset_folder + (" does not exist. Please "
              "specify the -f parameter correctly"))
        sys.exit(1)

    state = {}
    if os.path.isfile(args.state):
        state = load_json(args.state)

    algorithm_list = parse_sequence(args.algorithms)
    dataset_list = parse_sequence(args.datasets)
    for dataset in dataset_list:
        if dataset not in dataset_eval_functions.keys():
            print("Error: unsupported dataset " + dataset)
            sys.exit(1)
        if dataset not in os.listdir(args.dataset_folder):
            print("Error: " + os.path.join(args.dataset_folder, dataset) + (" does not exist. "
                  "Please download the dataset and follow the naming conventions "
                  "(use -h for more information)"))
            sys.exit(1)

    for dataset in dataset_list:
        if dataset not in state.keys():
            state[dataset] = {}
        for algorithm in algorithm_list:
            # Resume the last evaluation instance if it was not marked as finished;
            # otherwise start a new timestamped instance for this algorithm.
            if algorithm in state[dataset].keys():
                last_eval_instance = state[dataset][algorithm][-1]
                if "finished" not in last_eval_instance.keys():
                    print("Continuing an unfinished evaluation of " +
                          algorithm + " started at " + last_eval_instance["timestamp"])
                else:
                    state[dataset][algorithm].append({"timestamp":
                        datetime.datetime.now().strftime("%Y-%m-%d--%H-%M")})
            else:
                state[dataset][algorithm] = [{"timestamp":
                    datetime.datetime.now().strftime("%Y-%m-%d--%H-%M")}]
            save_json(state, args.state)
            dataset_eval_functions[dataset](args.dataset_folder, algorithm, args.bin_path,
                                            state, args.state)
            state[dataset][algorithm][-1]["finished"] = True
            save_json(state, args.state)
    save_json(state, args.state)

    create_dir(args.out)
    for dataset in dataset_list:
        build_chart(args.out, state, dataset)
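# Example invocation (illustrative; adjust the executable and folder paths to your setup):
#   python optical_flow_benchmark.py ./optflow-example-optical_flow_evaluation \
#       -a farneback,tvl1 -d mpi_sintel -f ./OF_datasets -o ./OF_evaluation_results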