colorization.cpp 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. // This file is part of OpenCV project.
  2. // It is subject to the license terms in the LICENSE file found in the top-level directory
  3. // of this distribution and at http://opencv.org/license.html
  4. #include <opencv2/dnn.hpp>
  5. #include <opencv2/imgproc.hpp>
  6. #include <opencv2/highgui.hpp>
  7. #include <iostream>
  8. using namespace cv;
  9. using namespace cv::dnn;
  10. using namespace std;
  11. // the 313 ab cluster centers from pts_in_hull.npy (already transposed)
  12. static float hull_pts[] = {
  13. -90., -90., -90., -90., -90., -80., -80., -80., -80., -80., -80., -80., -80., -70., -70., -70., -70., -70., -70., -70., -70.,
  14. -70., -70., -60., -60., -60., -60., -60., -60., -60., -60., -60., -60., -60., -60., -50., -50., -50., -50., -50., -50., -50., -50.,
  15. -50., -50., -50., -50., -50., -50., -40., -40., -40., -40., -40., -40., -40., -40., -40., -40., -40., -40., -40., -40., -40., -30.,
  16. -30., -30., -30., -30., -30., -30., -30., -30., -30., -30., -30., -30., -30., -30., -30., -20., -20., -20., -20., -20., -20., -20.,
  17. -20., -20., -20., -20., -20., -20., -20., -20., -20., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,
  18. -10., -10., -10., -10., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 10., 10., 10., 10., 10., 10., 10.,
  19. 10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 20., 20., 20., 20., 20., 20., 20., 20., 20., 20., 20., 20., 20., 20., 20.,
  20. 20., 20., 20., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 40., 40., 40., 40.,
  21. 40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 50., 50., 50., 50., 50., 50., 50., 50., 50., 50.,
  22. 50., 50., 50., 50., 50., 50., 50., 50., 50., 60., 60., 60., 60., 60., 60., 60., 60., 60., 60., 60., 60., 60., 60., 60., 60., 60.,
  23. 60., 60., 60., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 80., 80., 80.,
  24. 80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 90., 90., 90., 90., 90., 90., 90., 90., 90., 90.,
  25. 90., 90., 90., 90., 90., 90., 90., 90., 90., 100., 100., 100., 100., 100., 100., 100., 100., 100., 100., 50., 60., 70., 80., 90.,
  26. 20., 30., 40., 50., 60., 70., 80., 90., 0., 10., 20., 30., 40., 50., 60., 70., 80., 90., -20., -10., 0., 10., 20., 30., 40., 50.,
  27. 60., 70., 80., 90., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., 80., 90., 100., -40., -30., -20., -10., 0., 10., 20.,
  28. 30., 40., 50., 60., 70., 80., 90., 100., -50., -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., 80., 90., 100., -50.,
  29. -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., 80., 90., 100., -60., -50., -40., -30., -20., -10., 0., 10., 20.,
  30. 30., 40., 50., 60., 70., 80., 90., 100., -70., -60., -50., -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., 80., 90.,
  31. 100., -80., -70., -60., -50., -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., 80., 90., -80., -70., -60., -50.,
  32. -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., 80., 90., -90., -80., -70., -60., -50., -40., -30., -20., -10.,
  33. 0., 10., 20., 30., 40., 50., 60., 70., 80., 90., -100., -90., -80., -70., -60., -50., -40., -30., -20., -10., 0., 10., 20., 30.,
  34. 40., 50., 60., 70., 80., 90., -100., -90., -80., -70., -60., -50., -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70.,
  35. 80., -110., -100., -90., -80., -70., -60., -50., -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., 80., -110., -100.,
  36. -90., -80., -70., -60., -50., -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., 80., -110., -100., -90., -80., -70.,
  37. -60., -50., -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., -110., -100., -90., -80., -70., -60., -50., -40., -30.,
  38. -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., -90., -80., -70., -60., -50., -40., -30., -20., -10., 0.
  39. };
  40. int main(int argc, char **argv)
  41. {
  42. const string about =
  43. "This sample demonstrates recoloring grayscale images with dnn.\n"
  44. "This program is based on:\n"
  45. " http://richzhang.github.io/colorization\n"
  46. " https://github.com/richzhang/colorization\n"
  47. "Download caffemodel and prototxt files:\n"
  48. " http://eecs.berkeley.edu/~rich.zhang/projects/2016_colorization/files/demo_v2/colorization_release_v2.caffemodel\n"
  49. " https://raw.githubusercontent.com/richzhang/colorization/caffe/models/colorization_deploy_v2.prototxt\n";
  50. const string keys =
  51. "{ h help | | print this help message }"
  52. "{ proto | colorization_deploy_v2.prototxt | model configuration }"
  53. "{ model | colorization_release_v2.caffemodel | model weights }"
  54. "{ image | space_shuttle.jpg | path to image file }"
  55. "{ opencl | | enable OpenCL }";
  56. CommandLineParser parser(argc, argv, keys);
  57. parser.about(about);
  58. if (parser.has("help"))
  59. {
  60. parser.printMessage();
  61. return 0;
  62. }
  63. string modelTxt = samples::findFile(parser.get<string>("proto"));
  64. string modelBin = samples::findFile(parser.get<string>("model"));
  65. string imageFile = samples::findFile(parser.get<string>("image"));
  66. bool useOpenCL = parser.has("opencl");
  67. if (!parser.check())
  68. {
  69. parser.printErrors();
  70. return 1;
  71. }
  72. Mat img = imread(imageFile);
  73. if (img.empty())
  74. {
  75. cout << "Can't read image from file: " << imageFile << endl;
  76. return 2;
  77. }
  78. // fixed input size for the pretrained network
  79. const int W_in = 224;
  80. const int H_in = 224;
  81. Net net = dnn::readNetFromCaffe(modelTxt, modelBin);
  82. if (useOpenCL)
  83. net.setPreferableTarget(DNN_TARGET_OPENCL);
  84. // setup additional layers:
  85. int sz[] = {2, 313, 1, 1};
  86. const Mat pts_in_hull(4, sz, CV_32F, hull_pts);
  87. Ptr<dnn::Layer> class8_ab = net.getLayer("class8_ab");
  88. class8_ab->blobs.push_back(pts_in_hull);
  89. Ptr<dnn::Layer> conv8_313_rh = net.getLayer("conv8_313_rh");
  90. conv8_313_rh->blobs.push_back(Mat(1, 313, CV_32F, Scalar(2.606)));
  91. // extract L channel and subtract mean
  92. Mat lab, L, input;
  93. img.convertTo(img, CV_32F, 1.0/255);
  94. cvtColor(img, lab, COLOR_BGR2Lab);
  95. extractChannel(lab, L, 0);
  96. resize(L, input, Size(W_in, H_in));
  97. input -= 50;
  98. // run the L channel through the network
  99. Mat inputBlob = blobFromImage(input);
  100. net.setInput(inputBlob);
  101. Mat result = net.forward();
  102. // retrieve the calculated a,b channels from the network output
  103. Size siz(result.size[2], result.size[3]);
  104. Mat a = Mat(siz, CV_32F, result.ptr(0,0));
  105. Mat b = Mat(siz, CV_32F, result.ptr(0,1));
  106. resize(a, a, img.size());
  107. resize(b, b, img.size());
  108. // merge, and convert back to BGR
  109. Mat color, chn[] = {L, a, b};
  110. merge(chn, 3, lab);
  111. cvtColor(lab, color, COLOR_Lab2BGR);
  112. imshow("color", color);
  113. imshow("original", img);
  114. waitKey();
  115. return 0;
  116. }