// text_detection.cpp (27 KB)
  1. #include <algorithm>
  2. #include <cctype>
  3. #include <cmath>
  4. #include <iostream>
  5. #include <limits>
  6. #include <numeric>
  7. #include <stdexcept>
  8. #include <string>
  9. #include <vector>
  10. #include <opencv2/gapi.hpp>
  11. #include <opencv2/gapi/core.hpp>
  12. #include <opencv2/gapi/cpu/gcpukernel.hpp>
  13. #include <opencv2/gapi/infer.hpp>
  14. #include <opencv2/gapi/infer/ie.hpp>
  15. #include <opencv2/gapi/streaming/cap.hpp>
  16. #include <opencv2/highgui.hpp>
  17. #include <opencv2/core/utility.hpp>
  // Program description; main() hands it to cv::CommandLineParser::about().
  const std::string about{
      "This is an OpenCV-based version of OMZ Text Detection example"
  };

  // Command-line option table passed to cv::CommandLineParser in main().
  // Every entry has the form "{ name [alias] | default value | description }".
  const std::string keys{
      "{ h help | | Print this help message }"
      "{ input | | Path to the input video file }"
      "{ tdm | text-detection-0004.xml | Path to OpenVINO text detection model (.xml), versions 0003 and 0004 work }"
      "{ tdd | CPU | Target device for the text detector (e.g. CPU, GPU, VPU, ...) }"
      "{ trm | text-recognition-0012.xml | Path to OpenVINO text recognition model (.xml) }"
      "{ trd | CPU | Target device for the text recognition (e.g. CPU, GPU, VPU, ...) }"
      "{ bw | 0 | CTC beam search decoder bandwidth, if 0, a CTC greedy decoder is used}"
      "{ sset | 0123456789abcdefghijklmnopqrstuvwxyz | Symbol set to use with text recognition decoder. Shouldn't contain symbol #. }"
      "{ thr | 0.2 | Text recognition confidence threshold}"
  };
  31. namespace {
  32. std::string weights_path(const std::string &model_path) {
  33. const auto EXT_LEN = 4u;
  34. const auto sz = model_path.size();
  35. CV_Assert(sz > EXT_LEN);
  36. const auto ext = model_path.substr(sz - EXT_LEN);
  37. CV_Assert(cv::toLowerCase(ext) == ".xml");
  38. return model_path.substr(0u, sz - EXT_LEN) + ".bin";
  39. }
  40. //////////////////////////////////////////////////////////////////////
  41. // Taken from OMZ samples as-is
  // Picks the largest value in [begin, end) and reports its softmax probability.
  //
  //   argmax - receives the offset of the (first) largest element from begin
  //   prob   - receives 1 / sum(exp(x - max)), i.e. the softmax of that element
  //
  // Throws std::logic_error when the sum of exponentials is below the double
  // epsilon. The range must not be empty (the maximum is dereferenced).
  template<typename Iter>
  void softmax_and_choose(Iter begin, Iter end, int *argmax, float *prob) {
      const Iter best = std::max_element(begin, end);
      *argmax = static_cast<int>(std::distance(begin, best));

      // Shifting by the maximum keeps every exponent at or below zero.
      const float peak = *best;
      double denominator = 0;
      for (Iter cur = begin; cur != end; ++cur) {
          denominator += std::exp((*cur) - peak);
      }

      if (std::fabs(denominator) < std::numeric_limits<double>::epsilon()) {
          throw std::logic_error("sum can't be equal to zero");
      }
      *prob = 1.0f / static_cast<float>(denominator);
  }
  // Returns the softmax of the values in [begin, end) as a vector of floats.
  //
  // The largest input is subtracted before exponentiation. This leaves the
  // result mathematically unchanged but keeps every exponent at or below zero,
  // so large logits no longer overflow to infinity (which used to turn the
  // whole output into NaN after normalization).
  //
  // An empty range yields an empty vector.
  template<typename Iter>
  std::vector<float> softmax(Iter begin, Iter end) {
      std::vector<float> prob(begin, end);
      if (prob.empty()) {
          return prob;
      }

      const float max_val = *std::max_element(prob.begin(), prob.end());
      float sum = 0.0f;
      for (float &p : prob) {
          p = std::exp(p - max_val);
          sum += p;
      }
      // The maximum itself contributes exp(0) == 1, so sum is never zero here.
      for (float &p : prob) {
          p /= sum;
      }
      return prob;
  }
  // One hypothesis kept by the CTC beam search decoder.
  struct BeamElement {
      // The sequence of chars that will be a result of the beam element.
      std::vector<int> sentence;

      // The probability that the last char in CTC sequence for the beam element
      // is the special blank char.
      float prob_blank;

      // The probability that the last char in CTC sequence for the beam element
      // is NOT the special blank char.
      float prob_not_blank;

      // The probability of the beam element: the sum of both components.
      float prob() const {
          return prob_not_blank + prob_blank;
      }
  };
  75. std::string CTCGreedyDecoder(const float *data,
  76. const std::size_t sz,
  77. const std::string &alphabet,
  78. const char pad_symbol,
  79. double *conf) {
  80. std::string res = "";
  81. bool prev_pad = false;
  82. *conf = 1;
  83. const auto num_classes = alphabet.length();
  84. for (auto it = data; it != (data+sz); it += num_classes) {
  85. int argmax = 0;
  86. float prob = 0.f;
  87. softmax_and_choose(it, it + num_classes, &argmax, &prob);
  88. (*conf) *= prob;
  89. auto symbol = alphabet[argmax];
  90. if (symbol != pad_symbol) {
  91. if (res.empty() || prev_pad || (!res.empty() && symbol != res.back())) {
  92. prev_pad = false;
  93. res += symbol;
  94. }
  95. } else {
  96. prev_pad = true;
  97. }
  98. }
  99. return res;
  100. }
  101. std::string CTCBeamSearchDecoder(const float *data,
  102. const std::size_t sz,
  103. const std::string &alphabet,
  104. double *conf,
  105. int bandwidth) {
  106. const auto num_classes = alphabet.length();
  107. std::vector<BeamElement> curr;
  108. std::vector<BeamElement> last;
  109. last.push_back(BeamElement{std::vector<int>(), 1.f, 0.f});
  110. for (auto it = data; it != (data+sz); it += num_classes) {
  111. curr.clear();
  112. std::vector<float> prob = softmax(it, it + num_classes);
  113. for(const auto& candidate: last) {
  114. float prob_not_blank = 0.f;
  115. const std::vector<int>& candidate_sentence = candidate.sentence;
  116. if (!candidate_sentence.empty()) {
  117. int n = candidate_sentence.back();
  118. prob_not_blank = candidate.prob_not_blank * prob[n];
  119. }
  120. float prob_blank = candidate.prob() * prob[num_classes - 1];
  121. auto check_res = std::find_if(curr.begin(),
  122. curr.end(),
  123. [&candidate_sentence](const BeamElement& n) {
  124. return n.sentence == candidate_sentence;
  125. });
  126. if (check_res == std::end(curr)) {
  127. curr.push_back(BeamElement{candidate.sentence, prob_blank, prob_not_blank});
  128. } else {
  129. check_res->prob_not_blank += prob_not_blank;
  130. if (check_res->prob_blank != 0.f) {
  131. throw std::logic_error("Probability that the last char in CTC-sequence "
  132. "is the special blank char must be zero here");
  133. }
  134. check_res->prob_blank = prob_blank;
  135. }
  136. for (int i = 0; i < static_cast<int>(num_classes) - 1; i++) {
  137. auto extend = candidate_sentence;
  138. extend.push_back(i);
  139. if (candidate_sentence.size() > 0 && candidate.sentence.back() == i) {
  140. prob_not_blank = prob[i] * candidate.prob_blank;
  141. } else {
  142. prob_not_blank = prob[i] * candidate.prob();
  143. }
  144. auto check_res2 = std::find_if(curr.begin(),
  145. curr.end(),
  146. [&extend](const BeamElement &n) {
  147. return n.sentence == extend;
  148. });
  149. if (check_res2 == std::end(curr)) {
  150. curr.push_back(BeamElement{extend, 0.f, prob_not_blank});
  151. } else {
  152. check_res2->prob_not_blank += prob_not_blank;
  153. }
  154. }
  155. }
  156. sort(curr.begin(), curr.end(), [](const BeamElement &a, const BeamElement &b) -> bool {
  157. return a.prob() > b.prob();
  158. });
  159. last.clear();
  160. int num_to_copy = std::min(bandwidth, static_cast<int>(curr.size()));
  161. for (int b = 0; b < num_to_copy; b++) {
  162. last.push_back(curr[b]);
  163. }
  164. }
  165. *conf = last[0].prob();
  166. std::string res="";
  167. for (const auto& idx: last[0].sentence) {
  168. res += alphabet[idx];
  169. }
  170. return res;
  171. }
  172. //////////////////////////////////////////////////////////////////////
  173. } // anonymous namespace
  174. namespace custom {
  175. namespace {
  176. //////////////////////////////////////////////////////////////////////
  177. // Define networks for this sample
  178. using GMat2 = std::tuple<cv::GMat, cv::GMat>;
  179. G_API_NET(TextDetection,
  180. <GMat2(cv::GMat)>,
  181. "sample.custom.text_detect");
  182. G_API_NET(TextRecognition,
  183. <cv::GMat(cv::GMat)>,
  184. "sample.custom.text_recogn");
  185. // Define custom operations
  186. using GSize = cv::GOpaque<cv::Size>;
  187. using GRRects = cv::GArray<cv::RotatedRect>;
  188. G_API_OP(PostProcess,
  189. <GRRects(cv::GMat,cv::GMat,GSize,float,float)>,
  190. "sample.custom.text.post_proc") {
  191. static cv::GArrayDesc outMeta(const cv::GMatDesc &,
  192. const cv::GMatDesc &,
  193. const cv::GOpaqueDesc &,
  194. float,
  195. float) {
  196. return cv::empty_array_desc();
  197. }
  198. };
  199. using GMats = cv::GArray<cv::GMat>;
  200. G_API_OP(CropLabels,
  201. <GMats(cv::GMat,GRRects,GSize)>,
  202. "sample.custom.text.crop") {
  203. static cv::GArrayDesc outMeta(const cv::GMatDesc &,
  204. const cv::GArrayDesc &,
  205. const cv::GOpaqueDesc &) {
  206. return cv::empty_array_desc();
  207. }
  208. };
  209. //////////////////////////////////////////////////////////////////////
  210. // Implement custom operations
  211. GAPI_OCV_KERNEL(OCVPostProcess, PostProcess) {
  212. static void run(const cv::Mat &link,
  213. const cv::Mat &segm,
  214. const cv::Size &img_size,
  215. const float link_threshold,
  216. const float segm_threshold,
  217. std::vector<cv::RotatedRect> &out) {
  218. // NOTE: Taken from the OMZ text detection sample almost as-is
  219. const int kMinArea = 300;
  220. const int kMinHeight = 10;
  221. const float *link_data_pointer = link.ptr<float>();
  222. std::vector<float> link_data(link_data_pointer, link_data_pointer + link.total());
  223. link_data = transpose4d(link_data, dimsToShape(link.size), {0, 2, 3, 1});
  224. softmax(link_data);
  225. link_data = sliceAndGetSecondChannel(link_data);
  226. std::vector<int> new_link_data_shape = {
  227. link.size[0],
  228. link.size[2],
  229. link.size[3],
  230. link.size[1]/2,
  231. };
  232. const float *cls_data_pointer = segm.ptr<float>();
  233. std::vector<float> cls_data(cls_data_pointer, cls_data_pointer + segm.total());
  234. cls_data = transpose4d(cls_data, dimsToShape(segm.size), {0, 2, 3, 1});
  235. softmax(cls_data);
  236. cls_data = sliceAndGetSecondChannel(cls_data);
  237. std::vector<int> new_cls_data_shape = {
  238. segm.size[0],
  239. segm.size[2],
  240. segm.size[3],
  241. segm.size[1]/2,
  242. };
  243. out = maskToBoxes(decodeImageByJoin(cls_data, new_cls_data_shape,
  244. link_data, new_link_data_shape,
  245. segm_threshold, link_threshold),
  246. static_cast<float>(kMinArea),
  247. static_cast<float>(kMinHeight),
  248. img_size);
  249. }
  250. static std::vector<std::size_t> dimsToShape(const cv::MatSize &sz) {
  251. const int n_dims = sz.dims();
  252. std::vector<std::size_t> result;
  253. result.reserve(n_dims);
  254. // cv::MatSize is not iterable...
  255. for (int i = 0; i < n_dims; i++) {
  256. result.emplace_back(static_cast<std::size_t>(sz[i]));
  257. }
  258. return result;
  259. }
  // In-place softmax over consecutive pairs of values: after the call every
  // pair (rdata[2k], rdata[2k + 1]) sums to one.
  //
  // Throws std::runtime_error when rdata holds an odd number of values; the
  // last element would otherwise be paired with memory past the end of the
  // vector.
  static void softmax(std::vector<float> &rdata) {
      // NOTE: Taken from the OMZ text detection sample almost as-is
      const size_t last_dim = 2;
      if (rdata.size() % last_dim != 0)
          throw std::runtime_error("Softmax input must hold an even number of values.");

      for (size_t i = 0; i < rdata.size(); i += last_dim) {
          float &first = rdata[i];
          float &second = rdata[i + 1];
          // Subtracting the larger value keeps both exponents at or below zero.
          const float m = std::max(first, second);
          first = std::exp(first - m);
          second = std::exp(second - m);
          const float s = first + second;
          first /= s;
          second /= s;
      }
  }
  // Returns a copy of a flat, row-major 4-D array with its axes permuted.
  //
  //   data  - source values; at least shape[0]*shape[1]*shape[2]*shape[3] long
  //   shape - the four source dimensions
  //   axes  - permutation of {0, 1, 2, 3}: output axis k is source axis axes[k]
  //
  // Throws std::runtime_error when shape/axes are not both of size four, when
  // axes is not a permutation, or when data is shorter than the shape
  // requires. Each of these used to lead to out-of-bounds accesses.
  static std::vector<float> transpose4d(const std::vector<float> &data,
                                        const std::vector<size_t> &shape,
                                        const std::vector<size_t> &axes) {
      // NOTE: Taken from the OMZ text detection sample almost as-is
      if (shape.size() != axes.size())
          throw std::runtime_error("Shape and axes must have the same dimension.");
      if (shape.size() != 4)
          throw std::runtime_error("Shape must have exactly four dimensions.");

      bool axis_used[4] = {false, false, false, false};
      for (size_t a : axes) {
          if (a >= shape.size())
              throw std::runtime_error("Axis must be less than dimension of shape.");
          if (axis_used[a])
              throw std::runtime_error("Axes must not contain duplicates.");
          axis_used[a] = true;
      }

      const size_t total_size = shape[0]*shape[1]*shape[2]*shape[3];
      if (data.size() < total_size)
          throw std::runtime_error("Data holds fewer elements than the shape requires.");

      // Row-major strides of the permuted (output) layout.
      const size_t steps[4] = {
          shape[axes[1]]*shape[axes[2]]*shape[axes[3]],
          shape[axes[2]]*shape[axes[3]],
          shape[axes[3]],
          1
      };

      std::vector<float> new_data(total_size, 0);
      size_t source_data_idx = 0;
      size_t ids[4] = {0, 0, 0, 0};
      // Walk the source linearly and scatter every value to its new position.
      for (ids[0] = 0; ids[0] < shape[0]; ids[0]++) {
          for (ids[1] = 0; ids[1] < shape[1]; ids[1]++) {
              for (ids[2] = 0; ids[2] < shape[2]; ids[2]++) {
                  for (ids[3] = 0; ids[3] < shape[3]; ids[3]++) {
                      const size_t new_data_idx = ids[axes[0]]*steps[0] + ids[axes[1]]*steps[1] +
                                                  ids[axes[2]]*steps[2] + ids[axes[3]]*steps[3];
                      new_data[new_data_idx] = data[source_data_idx++];
                  }
              }
          }
      }
      return new_data;
  }
  // Returns the values stored at the odd positions of data (1, 3, 5, ...),
  // i.e. the second value of every consecutive pair. A trailing unpaired
  // value is ignored.
  static std::vector<float> sliceAndGetSecondChannel(const std::vector<float> &data) {
      // NOTE: Taken from the OMZ text detection sample almost as-is
      std::vector<float> picked;
      picked.reserve(data.size() / 2);
      for (size_t src = 1; src < data.size(); src += 2) {
          picked.push_back(data[src]);
      }
      return picked;
  }
  313. static void join(const int p1,
  314. const int p2,
  315. std::unordered_map<int, int> &group_mask) {
  316. // NOTE: Taken from the OMZ text detection sample almost as-is
  317. const int root1 = findRoot(p1, group_mask);
  318. const int root2 = findRoot(p2, group_mask);
  319. if (root1 != root2) {
  320. group_mask[root1] = root2;
  321. }
  322. }
  323. static cv::Mat decodeImageByJoin(const std::vector<float> &cls_data,
  324. const std::vector<int> &cls_data_shape,
  325. const std::vector<float> &link_data,
  326. const std::vector<int> &link_data_shape,
  327. float cls_conf_threshold,
  328. float link_conf_threshold) {
  329. // NOTE: Taken from the OMZ text detection sample almost as-is
  330. const int h = cls_data_shape[1];
  331. const int w = cls_data_shape[2];
  332. std::vector<uchar> pixel_mask(h * w, 0);
  333. std::unordered_map<int, int> group_mask;
  334. std::vector<cv::Point> points;
  335. for (int i = 0; i < static_cast<int>(pixel_mask.size()); i++) {
  336. pixel_mask[i] = cls_data[i] >= cls_conf_threshold;
  337. if (pixel_mask[i]) {
  338. points.emplace_back(i % w, i / w);
  339. group_mask[i] = -1;
  340. }
  341. }
  342. std::vector<uchar> link_mask(link_data.size(), 0);
  343. for (size_t i = 0; i < link_mask.size(); i++) {
  344. link_mask[i] = link_data[i] >= link_conf_threshold;
  345. }
  346. size_t neighbours = size_t(link_data_shape[3]);
  347. for (const auto &point : points) {
  348. size_t neighbour = 0;
  349. for (int ny = point.y - 1; ny <= point.y + 1; ny++) {
  350. for (int nx = point.x - 1; nx <= point.x + 1; nx++) {
  351. if (nx == point.x && ny == point.y)
  352. continue;
  353. if (nx >= 0 && nx < w && ny >= 0 && ny < h) {
  354. uchar pixel_value = pixel_mask[size_t(ny) * size_t(w) + size_t(nx)];
  355. uchar link_value = link_mask[(size_t(point.y) * size_t(w) + size_t(point.x))
  356. *neighbours + neighbour];
  357. if (pixel_value && link_value) {
  358. join(point.x + point.y * w, nx + ny * w, group_mask);
  359. }
  360. }
  361. neighbour++;
  362. }
  363. }
  364. }
  365. return get_all(points, w, h, group_mask);
  366. }
  367. static cv::Mat get_all(const std::vector<cv::Point> &points,
  368. const int w,
  369. const int h,
  370. std::unordered_map<int, int> &group_mask) {
  371. // NOTE: Taken from the OMZ text detection sample almost as-is
  372. std::unordered_map<int, int> root_map;
  373. cv::Mat mask(h, w, CV_32S, cv::Scalar(0));
  374. for (const auto &point : points) {
  375. int point_root = findRoot(point.x + point.y * w, group_mask);
  376. if (root_map.find(point_root) == root_map.end()) {
  377. root_map.emplace(point_root, static_cast<int>(root_map.size() + 1));
  378. }
  379. mask.at<int>(point.x + point.y * w) = root_map[point_root];
  380. }
  381. return mask;
  382. }
  // Follows the parent links in group_mask from `point` until an entry with
  // parent -1 is reached and returns that entry's key.
  //
  // When `point` is not a root itself, its entry is rewritten to refer to the
  // root directly; the other entries on the path are left untouched.
  // std::unordered_map::at throws std::out_of_range for a key that is missing.
  static int findRoot(const int point,
                      std::unordered_map<int, int> &group_mask) {
      // NOTE: Taken from the OMZ text detection sample almost as-is
      int current = point;
      for (int parent = group_mask.at(current); parent != -1;
           parent = group_mask.at(current)) {
          current = parent;
      }
      if (current != point) {
          group_mask[point] = current;
      }
      return current;
  }
  397. static std::vector<cv::RotatedRect> maskToBoxes(const cv::Mat &mask,
  398. const float min_area,
  399. const float min_height,
  400. const cv::Size &image_size) {
  401. // NOTE: Taken from the OMZ text detection sample almost as-is
  402. std::vector<cv::RotatedRect> bboxes;
  403. double min_val = 0.;
  404. double max_val = 0.;
  405. cv::minMaxLoc(mask, &min_val, &max_val);
  406. int max_bbox_idx = static_cast<int>(max_val);
  407. cv::Mat resized_mask;
  408. cv::resize(mask, resized_mask, image_size, 0, 0, cv::INTER_NEAREST);
  409. for (int i = 1; i <= max_bbox_idx; i++) {
  410. cv::Mat bbox_mask = resized_mask == i;
  411. std::vector<std::vector<cv::Point>> contours;
  412. cv::findContours(bbox_mask, contours, cv::RETR_CCOMP, cv::CHAIN_APPROX_SIMPLE);
  413. if (contours.empty())
  414. continue;
  415. cv::RotatedRect r = cv::minAreaRect(contours[0]);
  416. if (std::min(r.size.width, r.size.height) < min_height)
  417. continue;
  418. if (r.size.area() < min_area)
  419. continue;
  420. bboxes.emplace_back(r);
  421. }
  422. return bboxes;
  423. }
  424. }; // GAPI_OCV_KERNEL(PostProcess)
  425. GAPI_OCV_KERNEL(OCVCropLabels, CropLabels) {
  426. static void run(const cv::Mat &image,
  427. const std::vector<cv::RotatedRect> &detections,
  428. const cv::Size &outSize,
  429. std::vector<cv::Mat> &out) {
  430. out.clear();
  431. out.reserve(detections.size());
  432. cv::Mat crop(outSize, CV_8UC3, cv::Scalar(0));
  433. cv::Mat gray(outSize, CV_8UC1, cv::Scalar(0));
  434. std::vector<int> blob_shape = {1,1,outSize.height,outSize.width};
  435. for (auto &&rr : detections) {
  436. std::vector<cv::Point2f> points(4);
  437. rr.points(points.data());
  438. const auto top_left_point_idx = topLeftPointIdx(points);
  439. cv::Point2f point0 = points[static_cast<size_t>(top_left_point_idx)];
  440. cv::Point2f point1 = points[(top_left_point_idx + 1) % 4];
  441. cv::Point2f point2 = points[(top_left_point_idx + 2) % 4];
  442. std::vector<cv::Point2f> from{point0, point1, point2};
  443. std::vector<cv::Point2f> to{
  444. cv::Point2f(0.0f, 0.0f),
  445. cv::Point2f(static_cast<float>(outSize.width-1), 0.0f),
  446. cv::Point2f(static_cast<float>(outSize.width-1),
  447. static_cast<float>(outSize.height-1))
  448. };
  449. cv::Mat M = cv::getAffineTransform(from, to);
  450. cv::warpAffine(image, crop, M, outSize);
  451. cv::cvtColor(crop, gray, cv::COLOR_BGR2GRAY);
  452. cv::Mat blob;
  453. gray.convertTo(blob, CV_32F);
  454. out.push_back(blob.reshape(1, blob_shape)); // pass as 1,1,H,W instead of H,W
  455. }
  456. }
  457. static int topLeftPointIdx(const std::vector<cv::Point2f> &points) {
  458. // NOTE: Taken from the OMZ text detection sample almost as-is
  459. cv::Point2f most_left(std::numeric_limits<float>::max(),
  460. std::numeric_limits<float>::max());
  461. cv::Point2f almost_most_left(std::numeric_limits<float>::max(),
  462. std::numeric_limits<float>::max());
  463. int most_left_idx = -1;
  464. int almost_most_left_idx = -1;
  465. for (size_t i = 0; i < points.size() ; i++) {
  466. if (most_left.x > points[i].x) {
  467. if (most_left.x < std::numeric_limits<float>::max()) {
  468. almost_most_left = most_left;
  469. almost_most_left_idx = most_left_idx;
  470. }
  471. most_left = points[i];
  472. most_left_idx = static_cast<int>(i);
  473. }
  474. if (almost_most_left.x > points[i].x && points[i] != most_left) {
  475. almost_most_left = points[i];
  476. almost_most_left_idx = static_cast<int>(i);
  477. }
  478. }
  479. if (almost_most_left.y < most_left.y) {
  480. most_left = almost_most_left;
  481. most_left_idx = almost_most_left_idx;
  482. }
  483. return most_left_idx;
  484. }
  485. }; // GAPI_OCV_KERNEL(CropLabels)
  486. } // anonymous namespace
  487. } // namespace custom
  488. namespace vis {
  489. namespace {
  490. void drawRotatedRect(cv::Mat &m, const cv::RotatedRect &rc) {
  491. std::vector<cv::Point2f> tmp_points(5);
  492. rc.points(tmp_points.data());
  493. tmp_points[4] = tmp_points[0];
  494. auto prev = tmp_points.begin(), it = prev+1;
  495. for (; it != tmp_points.end(); ++it) {
  496. cv::line(m, *prev, *it, cv::Scalar(50, 205, 50), 2);
  497. prev = it;
  498. }
  499. }
  500. void drawText(cv::Mat &m, const cv::RotatedRect &rc, const std::string &str) {
  501. const int fface = cv::FONT_HERSHEY_SIMPLEX;
  502. const double scale = 0.7;
  503. const int thick = 1;
  504. int base = 0;
  505. const auto text_size = cv::getTextSize(str, fface, scale, thick, &base);
  506. std::vector<cv::Point2f> tmp_points(4);
  507. rc.points(tmp_points.data());
  508. const auto tl_point_idx = custom::OCVCropLabels::topLeftPointIdx(tmp_points);
  509. cv::Point text_pos = tmp_points[tl_point_idx];
  510. text_pos.x = std::max(0, text_pos.x);
  511. text_pos.y = std::max(text_size.height, text_pos.y);
  512. cv::rectangle(m,
  513. text_pos + cv::Point{0, base},
  514. text_pos + cv::Point{text_size.width, -text_size.height},
  515. CV_RGB(50, 205, 50),
  516. cv::FILLED);
  517. const auto white = CV_RGB(255, 255, 255);
  518. cv::putText(m, str, text_pos, fface, scale, white, thick, 8);
  519. }
  520. } // anonymous namespace
  521. } // namespace vis
  522. int main(int argc, char *argv[])
  523. {
  524. cv::CommandLineParser cmd(argc, argv, keys);
  525. cmd.about(about);
  526. if (cmd.has("help")) {
  527. cmd.printMessage();
  528. return 0;
  529. }
  530. const auto input_file_name = cmd.get<std::string>("input");
  531. const auto tdet_model_path = cmd.get<std::string>("tdm");
  532. const auto trec_model_path = cmd.get<std::string>("trm");
  533. const auto tdet_target_dev = cmd.get<std::string>("tdd");
  534. const auto trec_target_dev = cmd.get<std::string>("trd");
  535. const auto ctc_beam_dec_bw = cmd.get<int>("bw");
  536. const auto dec_conf_thresh = cmd.get<double>("thr");
  537. const auto pad_symbol = '#';
  538. const auto symbol_set = cmd.get<std::string>("sset") + pad_symbol;
  539. cv::GMat in;
  540. cv::GOpaque<cv::Size> in_rec_sz;
  541. cv::GMat link, segm;
  542. std::tie(link, segm) = cv::gapi::infer<custom::TextDetection>(in);
  543. cv::GOpaque<cv::Size> size = cv::gapi::streaming::size(in);
  544. cv::GArray<cv::RotatedRect> rrs = custom::PostProcess::on(link, segm, size, 0.8f, 0.8f);
  545. cv::GArray<cv::GMat> labels = custom::CropLabels::on(in, rrs, in_rec_sz);
  546. cv::GArray<cv::GMat> text = cv::gapi::infer2<custom::TextRecognition>(in, labels);
  547. cv::GComputation graph(cv::GIn(in, in_rec_sz),
  548. cv::GOut(cv::gapi::copy(in), rrs, text));
  549. // Text detection network
  550. auto tdet_net = cv::gapi::ie::Params<custom::TextDetection> {
  551. tdet_model_path, // path to topology IR
  552. weights_path(tdet_model_path), // path to weights
  553. tdet_target_dev, // device specifier
  554. }.cfgOutputLayers({"model/link_logits_/add", "model/segm_logits/add"});
  555. auto trec_net = cv::gapi::ie::Params<custom::TextRecognition> {
  556. trec_model_path, // path to topology IR
  557. weights_path(trec_model_path), // path to weights
  558. trec_target_dev, // device specifier
  559. };
  560. auto networks = cv::gapi::networks(tdet_net, trec_net);
  561. auto kernels = cv::gapi::kernels< custom::OCVPostProcess
  562. , custom::OCVCropLabels
  563. >();
  564. auto pipeline = graph.compileStreaming(cv::compile_args(kernels, networks));
  565. std::cout << "Reading " << input_file_name << std::endl;
  566. // Input stream
  567. auto in_src = cv::gapi::wip::make_src<cv::gapi::wip::GCaptureSource>(input_file_name);
  568. // Text recognition input size (also an input parameter to the graph)
  569. auto in_rsz = cv::Size{ 120, 32 };
  570. // Set the pipeline source & start the pipeline
  571. pipeline.setSource(cv::gin(in_src, in_rsz));
  572. pipeline.start();
  573. // Declare the output data & run the processing loop
  574. cv::TickMeter tm;
  575. cv::Mat image;
  576. std::vector<cv::RotatedRect> out_rcs;
  577. std::vector<cv::Mat> out_text;
  578. tm.start();
  579. int frames = 0;
  580. while (pipeline.pull(cv::gout(image, out_rcs, out_text))) {
  581. frames++;
  582. CV_Assert(out_rcs.size() == out_text.size());
  583. const auto num_labels = out_rcs.size();
  584. std::vector<cv::Point2f> tmp_points(4);
  585. for (std::size_t l = 0; l < num_labels; l++) {
  586. // Decode the recognized text in the rectangle
  587. const auto &blob = out_text[l];
  588. const float *data = blob.ptr<float>();
  589. const auto sz = blob.total();
  590. double conf = 1.0;
  591. const std::string res = ctc_beam_dec_bw == 0
  592. ? CTCGreedyDecoder(data, sz, symbol_set, pad_symbol, &conf)
  593. : CTCBeamSearchDecoder(data, sz, symbol_set, &conf, ctc_beam_dec_bw);
  594. // Draw a bounding box for this rotated rectangle
  595. const auto &rc = out_rcs[l];
  596. vis::drawRotatedRect(image, rc);
  597. // Draw text, if decoded
  598. if (conf >= dec_conf_thresh) {
  599. vis::drawText(image, rc, res);
  600. }
  601. }
  602. tm.stop();
  603. cv::imshow("Out", image);
  604. cv::waitKey(1);
  605. tm.start();
  606. }
  607. tm.stop();
  608. std::cout << "Processed " << frames << " frames"
  609. << " (" << frames / tm.getTimeSec() << " FPS)" << std::endl;
  610. return 0;
  611. }