gpc_train_middlebury.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import argparse
  2. import glob
  3. import os
  4. import subprocess
  5. def execute(cmd):
  6. popen = subprocess.Popen(cmd,
  7. stdout=subprocess.PIPE,
  8. stderr=subprocess.PIPE)
  9. for stdout_line in iter(popen.stdout.readline, ''):
  10. print(stdout_line.rstrip())
  11. for stderr_line in iter(popen.stderr.readline, ''):
  12. print(stderr_line.rstrip())
  13. popen.stdout.close()
  14. popen.stderr.close()
  15. return_code = popen.wait()
  16. if return_code != 0:
  17. raise subprocess.CalledProcessError(return_code, cmd)
  18. def main():
  19. parser = argparse.ArgumentParser(
  20. description='Train Global Patch Collider using Middlebury dataset')
  21. parser.add_argument(
  22. '--bin_path',
  23. help='Path to the training executable (example_optflow_gpc_train)',
  24. required=True)
  25. parser.add_argument('--dataset_path',
  26. help='Path to the directory with frames',
  27. required=True)
  28. parser.add_argument('--gt_path',
  29. help='Path to the directory with ground truth flow',
  30. required=True)
  31. parser.add_argument('--descriptor_type',
  32. help='Descriptor type',
  33. type=int,
  34. default=0)
  35. args = parser.parse_args()
  36. seq = glob.glob(os.path.join(args.dataset_path, '*'))
  37. seq.sort()
  38. input_files = []
  39. for s in seq:
  40. if os.path.isdir(s):
  41. seq_name = os.path.basename(s)
  42. frames = glob.glob(os.path.join(s, 'frame*.png'))
  43. frames.sort()
  44. assert (len(frames) == 2)
  45. assert (os.path.basename(frames[0]) == 'frame10.png')
  46. assert (os.path.basename(frames[1]) == 'frame11.png')
  47. gt_flow = os.path.join(args.gt_path, seq_name, 'flow10.flo')
  48. if os.path.isfile(gt_flow):
  49. input_files += [frames[0], frames[1], gt_flow]
  50. execute([args.bin_path, '--descriptor-type=%d' % args.descriptor_type] + input_files)
  51. if __name__ == '__main__':
  52. main()