dnn_superres_benchmark_quality.cpp 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. // This file is part of OpenCV project.
  2. // It is subject to the license terms in the LICENSE file found in the top-level directory
  3. // of this distribution and at http://opencv.org/license.html.
  4. #include <iostream>
  5. #include <opencv2/opencv_modules.hpp>
  6. #ifdef HAVE_OPENCV_QUALITY
  7. #include <opencv2/dnn_superres.hpp>
  8. #include <opencv2/quality.hpp>
  9. #include <opencv2/imgproc.hpp>
  10. #include <opencv2/highgui.hpp>
  11. using namespace std;
  12. using namespace cv;
  13. using namespace dnn_superres;
  14. static void showBenchmark(vector<Mat> images, string title, Size imageSize,
  15. const vector<String> imageTitles,
  16. const vector<double> psnrValues,
  17. const vector<double> ssimValues)
  18. {
  19. int fontFace = FONT_HERSHEY_COMPLEX_SMALL;
  20. int fontScale = 1;
  21. Scalar fontColor = Scalar(255, 255, 255);
  22. int len = static_cast<int>(images.size());
  23. int cols = 2, rows = 2;
  24. Mat fullImage = Mat::zeros(Size((cols * 10) + imageSize.width * cols, (rows * 10) + imageSize.height * rows),
  25. images[0].type());
  26. stringstream ss;
  27. int h_ = -1;
  28. for (int i = 0; i < len; i++) {
  29. int fontStart = 15;
  30. int w_ = i % cols;
  31. if (i % cols == 0)
  32. h_++;
  33. Rect ROI((w_ * (10 + imageSize.width)), (h_ * (10 + imageSize.height)), imageSize.width, imageSize.height);
  34. Mat tmp;
  35. resize(images[i], tmp, Size(ROI.width, ROI.height));
  36. ss << imageTitles[i];
  37. putText(tmp,
  38. ss.str(),
  39. Point(5, fontStart),
  40. fontFace,
  41. fontScale,
  42. fontColor,
  43. 1,
  44. 16);
  45. ss.str("");
  46. fontStart += 20;
  47. ss << "PSNR: " << psnrValues[i];
  48. putText(tmp,
  49. ss.str(),
  50. Point(5, fontStart),
  51. fontFace,
  52. fontScale,
  53. fontColor,
  54. 1,
  55. 16);
  56. ss.str("");
  57. fontStart += 20;
  58. ss << "SSIM: " << ssimValues[i];
  59. putText(tmp,
  60. ss.str(),
  61. Point(5, fontStart),
  62. fontFace,
  63. fontScale,
  64. fontColor,
  65. 1,
  66. 16);
  67. ss.str("");
  68. fontStart += 20;
  69. tmp.copyTo(fullImage(ROI));
  70. }
  71. namedWindow(title, 1);
  72. imshow(title, fullImage);
  73. waitKey();
  74. }
  75. static Vec2d getQualityValues(Mat orig, Mat upsampled)
  76. {
  77. double psnr = PSNR(upsampled, orig);
  78. Scalar q = quality::QualitySSIM::compute(upsampled, orig, noArray());
  79. double ssim = mean(Vec3d((q[0]), q[1], q[2]))[0];
  80. return Vec2d(psnr, ssim);
  81. }
  82. int main(int argc, char *argv[])
  83. {
  84. // Check for valid command line arguments, print usage
  85. // if insufficient arguments were given.
  86. if (argc < 4) {
  87. cout << "usage: Arg 1: image path | Path to image" << endl;
  88. cout << "\t Arg 2: algorithm | edsr, espcn, fsrcnn or lapsrn" << endl;
  89. cout << "\t Arg 3: path to model file 2 \n";
  90. cout << "\t Arg 4: scale | 2, 3, 4 or 8 \n";
  91. return -1;
  92. }
  93. string path = string(argv[1]);
  94. string algorithm = string(argv[2]);
  95. string model = string(argv[3]);
  96. int scale = atoi(argv[4]);
  97. Mat img = imread(path);
  98. if (img.empty()) {
  99. cerr << "Couldn't load image: " << img << "\n";
  100. return -2;
  101. }
  102. //Crop the image so the images will be aligned
  103. int width = img.cols - (img.cols % scale);
  104. int height = img.rows - (img.rows % scale);
  105. Mat cropped = img(Rect(0, 0, width, height));
  106. //Downscale the image for benchmarking
  107. Mat img_downscaled;
  108. resize(cropped, img_downscaled, Size(), 1.0 / scale, 1.0 / scale);
  109. //Make dnn super resolution instance
  110. DnnSuperResImpl sr;
  111. vector <Mat> allImages;
  112. Mat img_new;
  113. //Read and set the dnn model
  114. sr.readModel(model);
  115. sr.setModel(algorithm, scale);
  116. sr.upsample(img_downscaled, img_new);
  117. vector<double> psnrValues = vector<double>();
  118. vector<double> ssimValues = vector<double>();
  119. //DL MODEL
  120. Vec2f quality = getQualityValues(cropped, img_new);
  121. psnrValues.push_back(quality[0]);
  122. ssimValues.push_back(quality[1]);
  123. cout << sr.getAlgorithm() << ":" << endl;
  124. cout << "PSNR: " << quality[0] << " SSIM: " << quality[1] << endl;
  125. cout << "----------------------" << endl;
  126. //BICUBIC
  127. Mat bicubic;
  128. resize(img_downscaled, bicubic, Size(), scale, scale, INTER_CUBIC);
  129. quality = getQualityValues(cropped, bicubic);
  130. psnrValues.push_back(quality[0]);
  131. ssimValues.push_back(quality[1]);
  132. cout << "Bicubic " << endl;
  133. cout << "PSNR: " << quality[0] << " SSIM: " << quality[1] << endl;
  134. cout << "----------------------" << endl;
  135. //NEAREST NEIGHBOR
  136. Mat nearest;
  137. resize(img_downscaled, nearest, Size(), scale, scale, INTER_NEAREST);
  138. quality = getQualityValues(cropped, nearest);
  139. psnrValues.push_back(quality[0]);
  140. ssimValues.push_back(quality[1]);
  141. cout << "Nearest neighbor" << endl;
  142. cout << "PSNR: " << quality[0] << " SSIM: " << quality[1] << endl;
  143. cout << "----------------------" << endl;
  144. //LANCZOS
  145. Mat lanczos;
  146. resize(img_downscaled, lanczos, Size(), scale, scale, INTER_LANCZOS4);
  147. quality = getQualityValues(cropped, lanczos);
  148. psnrValues.push_back(quality[0]);
  149. ssimValues.push_back(quality[1]);
  150. cout << "Lanczos" << endl;
  151. cout << "PSNR: " << quality[0] << " SSIM: " << quality[1] << endl;
  152. cout << "-----------------------------------------------" << endl;
  153. vector <Mat> imgs{img_new, bicubic, nearest, lanczos};
  154. vector <String> titles{sr.getAlgorithm(), "Bicubic", "Nearest neighbor", "Lanczos"};
  155. showBenchmark(imgs, "Quality benchmark", Size(bicubic.cols, bicubic.rows), titles, psnrValues, ssimValues);
  156. waitKey(0);
  157. return 0;
  158. }
  159. #else
  160. int main()
  161. {
  162. std::cout << "This sample requires the OpenCV Quality module." << std::endl;
  163. return 0;
  164. }
  165. #endif