gpc_train_sintel.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. import argparse
  2. import glob
  3. import os
  4. import subprocess
  5. FRAME_DIST = 2
  6. assert (FRAME_DIST >= 1)
  7. def execute(cmd):
  8. popen = subprocess.Popen(cmd,
  9. stdout=subprocess.PIPE,
  10. stderr=subprocess.PIPE)
  11. for stdout_line in iter(popen.stdout.readline, ''):
  12. print(stdout_line.rstrip())
  13. for stderr_line in iter(popen.stderr.readline, ''):
  14. print(stderr_line.rstrip())
  15. popen.stdout.close()
  16. popen.stderr.close()
  17. return_code = popen.wait()
  18. if return_code != 0:
  19. raise subprocess.CalledProcessError(return_code, cmd)
  20. def main():
  21. parser = argparse.ArgumentParser(
  22. description='Train Global Patch Collider using MPI Sintel dataset')
  23. parser.add_argument(
  24. '--bin_path',
  25. help='Path to the training executable (example_optflow_gpc_train)',
  26. required=True)
  27. parser.add_argument('--dataset_path',
  28. help='Path to the directory with frames',
  29. required=True)
  30. parser.add_argument('--gt_path',
  31. help='Path to the directory with ground truth flow',
  32. required=True)
  33. parser.add_argument('--descriptor_type',
  34. help='Descriptor type',
  35. type=int,
  36. default=0)
  37. args = parser.parse_args()
  38. seq = glob.glob(os.path.join(args.dataset_path, '*'))
  39. seq.sort()
  40. input_files = []
  41. for s in seq:
  42. seq_name = os.path.basename(s)
  43. frames = glob.glob(os.path.join(s, 'frame*.png'))
  44. frames.sort()
  45. for i in range(0, len(frames) - 1, FRAME_DIST):
  46. gt_flow = os.path.join(args.gt_path, seq_name,
  47. os.path.basename(frames[i])[0:-4] + '.flo')
  48. assert (os.path.isfile(gt_flow))
  49. input_files += [frames[i], frames[i + 1], gt_flow]
  50. execute([args.bin_path, '--descriptor-type=%d' % args.descriptor_type] + input_files)
  51. if __name__ == '__main__':
  52. main()