gpc_train.cpp 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. #include "opencv2/optflow.hpp"
  2. #include <iostream>
  /* This tool trains the forest for the Global Patch Collider and saves it to the file given by
   * -f/--forest (default: "forest.yml.gz").
   */
  5. using namespace cv;
  /* Number of trees in the trained forest; used as the template argument of optflow::GPCForest. */
  static const int nTrees = 5;

  /* Option table handed to cv::CommandLineParser. Each entry has the form
   * "{names | default value | description}"; for example, -f/--forest defaults to "forest.yml.gz"
   * and --descriptor-type defaults to 0.
   */
  static const String keys =
      "{help h ? | | print this message}"
      "{max-tree-depth | | Maximum tree depth to stop partitioning}"
      "{min-samples | | Minimum number of samples in the node to stop partitioning}"
      "{descriptor-type|0 | Descriptor type. Set to 0 for quality, 1 for speed.}"
      "{print-progress | | Set to 0 to enable quiet mode, set to 1 to print progress}"
      "{f forest |forest.yml.gz| Path where to store resulting forest. It is recommended to use .yml.gz extension.}";
  /* Distributes the positional command-line arguments round-robin over three lists.
   * The 1st, 4th, 7th, ... positional argument goes to img1 (ImageFrom), the 2nd, 5th, ...
   * to img2 (ImageTo) and the 3rd, 6th, ... to gt (GroundTruth). argv[0] is never examined.
   * Arguments beginning with '-' are options for CommandLineParser; they are skipped and do
   * not advance the round-robin position.
   */
  static void fillInputImagesFromCommandLine( std::vector< String > &img1, std::vector< String > &img2, std::vector< String > &gt, int argc,
                                              const char **argv )
  {
    std::vector< String > *const targets[3] = { &img1, &img2, &gt };
    int positional = 0;
    for ( int argIndex = 1; argIndex < argc; ++argIndex )
    {
      const char *const arg = argv[argIndex];
      if ( arg[0] != '-' )
      {
        targets[positional % 3]->push_back( arg );
        ++positional;
      }
    }
  }
  29. int main( int argc, const char **argv )
  30. {
  31. CommandLineParser parser( argc, argv, keys );
  32. parser.about( "Global Patch Collider training tool" );
  33. std::vector< String > img1, img2, gt;
  34. optflow::GPCTrainingParams params;
  35. if ( parser.has( "max-tree-depth" ) )
  36. params.maxTreeDepth = parser.get< unsigned >( "max-tree-depth" );
  37. if ( parser.has( "min-samples" ) )
  38. params.minNumberOfSamples = parser.get< unsigned >( "min-samples" );
  39. if ( parser.has( "descriptor-type" ) )
  40. params.descriptorType = parser.get< int >( "descriptor-type" );
  41. if ( parser.has( "print-progress" ) )
  42. params.printProgress = parser.get< unsigned >( "print-progress" ) != 0;
  43. fillInputImagesFromCommandLine( img1, img2, gt, argc, argv );
  44. if ( parser.has( "help" ) || img1.size() != img2.size() || img1.size() != gt.size() || img1.size() == 0 )
  45. {
  46. std::cerr << "\nUsage: " << argv[0] << " [params] ImageFrom1 ImageTo1 GroundTruth1 ... ImageFromN ImageToN GroundTruthN\n" << std::endl;
  47. parser.printMessage();
  48. return 1;
  49. }
  50. Ptr< optflow::GPCForest< nTrees > > forest = optflow::GPCForest< nTrees >::create();
  51. forest->train( img1, img2, gt, params );
  52. forest->save( parser.get< String >( "forest" ) );
  53. return 0;
  54. }