dnn_superres_multioutput.cpp 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include <algorithm>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
#include <opencv2/dnn_superres.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
using namespace std;
using namespace cv;
using namespace dnn_superres;
  12. int main(int argc, char *argv[])
  13. {
  14. // Check for valid command line arguments, print usage
  15. // if insufficient arguments were given.
  16. if (argc < 4) {
  17. cout << "usage: Arg 1: image | Path to image" << endl;
  18. cout << "\t Arg 2: scales in a format of 2,4,8\n";
  19. cout << "\t Arg 3: output node names in a format of nchw_output_0,nchw_output_1\n";
  20. cout << "\t Arg 4: path to model file \n";
  21. return -1;
  22. }
  23. string img_path = string(argv[1]);
  24. string scales_str = string(argv[2]);
  25. string output_names_str = string(argv[3]);
  26. std::string path = string(argv[4]);
  27. //Parse the scaling factors
  28. std::vector<int> scales;
  29. char delim = ',';
  30. {
  31. std::stringstream ss(scales_str);
  32. std::string token;
  33. while (std::getline(ss, token, delim)) {
  34. scales.push_back(atoi(token.c_str()));
  35. }
  36. }
  37. //Parse the output node names
  38. std::vector<String> node_names;
  39. {
  40. std::stringstream ss(output_names_str);
  41. std::string token;
  42. while (std::getline(ss, token, delim)) {
  43. node_names.push_back(token);
  44. }
  45. }
  46. // Load the image
  47. Mat img = cv::imread(img_path);
  48. Mat original_img(img);
  49. if (img.empty())
  50. {
  51. std::cerr << "Couldn't load image: " << img << "\n";
  52. return -2;
  53. }
  54. //Make dnn super resolution instance
  55. DnnSuperResImpl sr;
  56. int scale = *max_element(scales.begin(), scales.end());
  57. std::vector<Mat> outputs;
  58. sr.readModel(path);
  59. sr.setModel("lapsrn", scale);
  60. sr.upsampleMultioutput(img, outputs, scales, node_names);
  61. for(unsigned int i = 0; i < outputs.size(); i++)
  62. {
  63. cv::namedWindow("Upsampled image", WINDOW_AUTOSIZE);
  64. cv::imshow("Upsampled image", outputs[i]);
  65. //cv::imwrite("./saved.jpg", img_new);
  66. cv::waitKey(0);
  67. }
  68. return 0;
  69. }