gpc_evaluate.cpp 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
#include "opencv2/core/ocl.hpp"
#include "opencv2/highgui.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/optflow.hpp"
#include <cmath>
#include <fstream>
#include <iostream>
#include <stdio.h>
/* This tool finds correspondences between two images using the Global Patch Collider
 * and computes the average endpoint error against the provided ground-truth flow.
 *
 * It needs a learned forest, read from the path given by --forest ("forest.yml.gz" by default).
 * You can obtain such a file either by training it yourself with another tool with the *_train suffix
 * or by downloading one of the forests trained on publicly available datasets from here:
 *
 * https://drive.google.com/open?id=0B7Hb8cfuzrIIZDFscXVYd0NBNFU
 */
  17. using namespace cv;
  18. const String keys = "{help h ? | | print this message}"
  19. "{@image1 |<none> | image1}"
  20. "{@image2 |<none> | image2}"
  21. "{@groundtruth |<none> | path to the .flo file}"
  22. "{@output | | output to a file instead of displaying, output image path}"
  23. "{g gpu | | use OpenCL}"
  24. "{f forest |forest.yml.gz| path to the forest.yml.gz}";
  25. const int nTrees = 5;
  26. static double normL2( const Point2f &v ) { return sqrt( v.x * v.x + v.y * v.y ); }
  27. static Vec3d getFlowColor( const Point2f &f, const bool logScale = true, const double scaleDown = 5 )
  28. {
  29. if ( f.x == 0 && f.y == 0 )
  30. return Vec3d( 0, 0, 1 );
  31. double radius = normL2( f );
  32. if ( logScale )
  33. radius = log( radius + 1 );
  34. radius /= scaleDown;
  35. radius = std::min( 1.0, radius );
  36. double angle = ( atan2( -f.y, -f.x ) + CV_PI ) * 180 / CV_PI;
  37. return Vec3d( angle, radius, 1 );
  38. }
  39. static void displayFlow( InputArray _flow, OutputArray _img )
  40. {
  41. const Size sz = _flow.size();
  42. Mat flow = _flow.getMat();
  43. _img.create( sz, CV_32FC3 );
  44. Mat img = _img.getMat();
  45. for ( int i = 0; i < sz.height; ++i )
  46. for ( int j = 0; j < sz.width; ++j )
  47. img.at< Vec3f >( i, j ) = getFlowColor( flow.at< Point2f >( i, j ) );
  48. cvtColor( img, img, COLOR_HSV2BGR );
  49. }
  50. static bool fileProbe( const char *name ) { return std::ifstream( name ).good(); }
  51. int main( int argc, const char **argv )
  52. {
  53. CommandLineParser parser( argc, argv, keys );
  54. parser.about( "Global Patch Collider evaluation tool" );
  55. if ( parser.has( "help" ) )
  56. {
  57. parser.printMessage();
  58. return 0;
  59. }
  60. String fromPath = parser.get< String >( 0 );
  61. String toPath = parser.get< String >( 1 );
  62. String gtPath = parser.get< String >( 2 );
  63. String outPath = parser.get< String >( 3 );
  64. const bool useOpenCL = parser.has( "gpu" );
  65. String forestDumpPath = parser.get< String >( "forest" );
  66. if ( !parser.check() )
  67. {
  68. parser.printErrors();
  69. return 1;
  70. }
  71. if ( !fileProbe( forestDumpPath.c_str() ) )
  72. {
  73. std::cerr << "Can't open the file with a trained model: `" << forestDumpPath
  74. << "`.\nYou can obtain this file either by manually training the model using another tool with *_train suffix or by "
  75. "downloading one of the files trained on some publicly available dataset from "
  76. "here:\nhttps://drive.google.com/open?id=0B7Hb8cfuzrIIZDFscXVYd0NBNFU"
  77. << std::endl;
  78. return 1;
  79. }
  80. ocl::setUseOpenCL( useOpenCL );
  81. Ptr< optflow::GPCForest< nTrees > > forest = Algorithm::load< optflow::GPCForest< nTrees > >( forestDumpPath );
  82. Mat from = imread( fromPath );
  83. Mat to = imread( toPath );
  84. Mat gt = readOpticalFlow( gtPath );
  85. std::vector< std::pair< Point2i, Point2i > > corr;
  86. TickMeter meter;
  87. meter.start();
  88. forest->findCorrespondences( from, to, corr, optflow::GPCMatchingParams( useOpenCL ) );
  89. meter.stop();
  90. std::cout << "Found " << corr.size() << " matches." << std::endl;
  91. std::cout << "Time: " << meter.getTimeSec() << " sec." << std::endl;
  92. double error = 0;
  93. int totalCorrectFlowVectors = 0;
  94. Mat dispErr = Mat::zeros( from.size(), CV_32FC3 );
  95. dispErr = Scalar( 0, 0, 1 );
  96. Mat disp = Mat::zeros( from.size(), CV_32FC3 );
  97. disp = Scalar( 0, 0, 1 );
  98. for ( size_t i = 0; i < corr.size(); ++i )
  99. {
  100. const Point2f a = corr[i].first;
  101. const Point2f b = corr[i].second;
  102. const Point2f gtDisplacement = gt.at< Point2f >( corr[i].first.y, corr[i].first.x );
  103. // Check that flow vector is correct
  104. if (!cvIsNaN(gtDisplacement.x) && !cvIsNaN(gtDisplacement.y) && gtDisplacement.x < 1e9 && gtDisplacement.y < 1e9)
  105. {
  106. const Point2f c = a + gtDisplacement;
  107. error += normL2( b - c );
  108. circle( dispErr, a, 3, getFlowColor( b - c, false, 32 ), -1 );
  109. ++totalCorrectFlowVectors;
  110. }
  111. circle( disp, a, 3, getFlowColor( b - a ), -1 );
  112. }
  113. if (totalCorrectFlowVectors)
  114. error /= totalCorrectFlowVectors;
  115. std::cout << "Average endpoint error: " << error << " px." << std::endl;
  116. cvtColor( disp, disp, COLOR_HSV2BGR );
  117. cvtColor( dispErr, dispErr, COLOR_HSV2BGR );
  118. Mat dispGroundTruth;
  119. displayFlow( gt, dispGroundTruth );
  120. if ( outPath.length() )
  121. {
  122. putText( disp, "Sparse matching: Global Patch Collider", Point2i( 24, 40 ), FONT_HERSHEY_DUPLEX, 1, Vec3b( 1, 0, 0 ), 2, LINE_AA );
  123. char buf[256];
  124. sprintf( buf, "Average EPE: %.2f", error );
  125. putText( disp, buf, Point2i( 24, 80 ), FONT_HERSHEY_DUPLEX, 1, Vec3b( 1, 0, 0 ), 2, LINE_AA );
  126. sprintf( buf, "Number of matches: %u", (unsigned)corr.size() );
  127. putText( disp, buf, Point2i( 24, 120 ), FONT_HERSHEY_DUPLEX, 1, Vec3b( 1, 0, 0 ), 2, LINE_AA );
  128. disp *= 255;
  129. imwrite( outPath, disp );
  130. return 0;
  131. }
  132. namedWindow( "Correspondences", WINDOW_AUTOSIZE );
  133. imshow( "Correspondences", disp );
  134. namedWindow( "Error", WINDOW_AUTOSIZE );
  135. imshow( "Error", dispErr );
  136. namedWindow( "Ground truth", WINDOW_AUTOSIZE );
  137. imshow( "Ground truth", dispGroundTruth );
  138. waitKey( 0 );
  139. return 0;
  140. }