opencl_custom_kernel.cpp 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. // This file is part of OpenCV project.
  2. // It is subject to the license terms in the LICENSE file found in the top-level directory
  3. // of this distribution and at http://opencv.org/license.html
  4. #include "opencv2/core.hpp"
  5. #include "opencv2/core/ocl.hpp"
  6. #include "opencv2/highgui.hpp"
  7. #include "opencv2/imgcodecs.hpp"
  8. #include "opencv2/imgproc.hpp"
  9. #include <iostream>
  10. using namespace std;
  11. using namespace cv;
  // OpenCL C source of the sample kernel `magnutude_filter_8u`; main() wraps it in a
  // cv::ocl::ProgramSource and builds it for the current device at run time.
  //
  // For each destination pixel (x, y) of an 8-bit single-channel image the kernel
  // computes, from the source pixel c and its four direct neighbours,
  //   dx = 2*c - left - right      dy = 2*c - up - down
  // and stores convert_uchar_sat(sqrt(dx*dx + dy*dy) * scale).
  // Pixels on the 1-pixel image border (where the cross-shaped neighbourhood would
  // leave the image) are written as 0. Work-items with x >= dst_cols or
  // y >= dst_rows write nothing.
  //
  // Kernel arguments, in order: src pointer, src_step, src_offset; dst pointer,
  // dst_step, dst_offset, dst_rows, dst_cols; scale. Steps and offsets are used
  // directly as uchar indices.
  // NOTE: src is addressed with dst's rows/cols, so src must be at least as large as dst.
  //
  // The interior test uses `y < dst_rows - 1` (previously `- 2`): the stencil reads
  // row y + 1 at most, so the last valid centre row is dst_rows - 2. This is the
  // same rule as the `x < dst_cols - 1` column test, and it avoids zeroing an
  // extra bottom row.
  static const char* opencl_kernel_src =
      "__kernel void magnutude_filter_8u(\n"
      " __global const uchar* src, int src_step, int src_offset,\n"
      " __global uchar* dst, int dst_step, int dst_offset, int dst_rows, int dst_cols,\n"
      " float scale)\n"
      "{\n"
      " int x = get_global_id(0);\n"
      " int y = get_global_id(1);\n"
      " if (x < dst_cols && y < dst_rows)\n"
      " {\n"
      " int dst_idx = y * dst_step + x + dst_offset;\n"
      " if (x > 0 && x < dst_cols - 1 && y > 0 && y < dst_rows - 1)\n"
      " {\n"
      " int src_idx = y * src_step + x + src_offset;\n"
      " int dx = (int)src[src_idx]*2 - src[src_idx - 1] - src[src_idx + 1];\n"
      " int dy = (int)src[src_idx]*2 - src[src_idx - 1*src_step] - src[src_idx + 1*src_step];\n"
      " dst[dst_idx] = convert_uchar_sat(sqrt((float)(dx*dx + dy*dy)) * scale);\n"
      " }\n"
      " else\n"
      " {\n"
      " dst[dst_idx] = 0;\n"
      " }\n"
      " }\n"
      "}\n";
  36. int main(int argc, char** argv)
  37. {
  38. const char* keys =
  39. "{ i input | | specify input image }"
  40. "{ h help | | print help message }";
  41. cv::CommandLineParser args(argc, argv, keys);
  42. if (args.has("help"))
  43. {
  44. cout << "Usage : " << argv[0] << " [options]" << endl;
  45. cout << "Available options:" << endl;
  46. args.printMessage();
  47. return EXIT_SUCCESS;
  48. }
  49. cv::ocl::Context ctx = cv::ocl::Context::getDefault();
  50. if (!ctx.ptr())
  51. {
  52. cerr << "OpenCL is not available" << endl;
  53. return 1;
  54. }
  55. cv::ocl::Device device = cv::ocl::Device::getDefault();
  56. if (!device.compilerAvailable())
  57. {
  58. cerr << "OpenCL compiler is not available" << endl;
  59. return 1;
  60. }
  61. UMat src;
  62. {
  63. string image_file = args.get<string>("i");
  64. if (!image_file.empty())
  65. {
  66. Mat image = imread(samples::findFile(image_file));
  67. if (image.empty())
  68. {
  69. cout << "error read image: " << image_file << endl;
  70. return 1;
  71. }
  72. cvtColor(image, src, COLOR_BGR2GRAY);
  73. }
  74. else
  75. {
  76. Mat frame(cv::Size(640, 480), CV_8U, Scalar::all(128));
  77. Point p(frame.cols / 2, frame.rows / 2);
  78. line(frame, Point(0, frame.rows / 2), Point(frame.cols, frame.rows / 2), 1);
  79. circle(frame, p, 200, Scalar(32, 32, 32), 8, LINE_AA);
  80. string str = "OpenCL";
  81. int baseLine = 0;
  82. Size box = getTextSize(str, FONT_HERSHEY_COMPLEX, 2, 5, &baseLine);
  83. putText(frame, str, Point((frame.cols - box.width) / 2, (frame.rows - box.height) / 2 + baseLine),
  84. FONT_HERSHEY_COMPLEX, 2, Scalar(255, 255, 255), 5, LINE_AA);
  85. frame.copyTo(src);
  86. }
  87. }
  88. cv::String module_name; // empty to disable OpenCL cache
  89. {
  90. cout << "OpenCL program source: " << endl;
  91. cout << "======================================================================================================" << endl;
  92. cout << opencl_kernel_src << endl;
  93. cout << "======================================================================================================" << endl;
  94. //! [Define OpenCL program source]
  95. cv::ocl::ProgramSource source(module_name, "simple", opencl_kernel_src, "");
  96. //! [Define OpenCL program source]
  97. //! [Compile/build OpenCL for current OpenCL device]
  98. cv::String errmsg;
  99. cv::ocl::Program program(source, "", errmsg);
  100. if (program.ptr() == NULL)
  101. {
  102. cerr << "Can't compile OpenCL program:" << endl << errmsg << endl;
  103. return 1;
  104. }
  105. //! [Compile/build OpenCL for current OpenCL device]
  106. if (!errmsg.empty())
  107. {
  108. cout << "OpenCL program build log:" << endl << errmsg << endl;
  109. }
  110. //! [Get OpenCL kernel by name]
  111. cv::ocl::Kernel k("magnutude_filter_8u", program);
  112. if (k.empty())
  113. {
  114. cerr << "Can't get OpenCL kernel" << endl;
  115. return 1;
  116. }
  117. //! [Get OpenCL kernel by name]
  118. UMat result(src.size(), CV_8UC1);
  119. //! [Define kernel parameters and run]
  120. size_t globalSize[2] = {(size_t)src.cols, (size_t)src.rows};
  121. size_t localSize[2] = {8, 8};
  122. bool executionResult = k
  123. .args(
  124. cv::ocl::KernelArg::ReadOnlyNoSize(src), // size is not used (similar to 'dst' size)
  125. cv::ocl::KernelArg::WriteOnly(result),
  126. (float)2.0
  127. )
  128. .run(2, globalSize, localSize, true);
  129. if (!executionResult)
  130. {
  131. cerr << "OpenCL kernel launch failed" << endl;
  132. return 1;
  133. }
  134. //! [Define kernel parameters and run]
  135. imshow("Source", src);
  136. imshow("Result", result);
  137. for (;;)
  138. {
  139. int key = waitKey();
  140. if (key == 27/*ESC*/ || key == 'q' || key == 'Q')
  141. break;
  142. }
  143. }
  144. return 0;
  145. }