// digits_svm.cpp - OpenCV sample: SVM and KNearest handwritten digit recognition.
  1. #include "opencv2/core.hpp"
  2. #include "opencv2/highgui.hpp"
  3. #include "opencv2/imgcodecs.hpp"
  4. #include "opencv2/imgproc.hpp"
  5. #include "opencv2/ml.hpp"
  6. #include <algorithm>
  7. #include <iostream>
  8. #include <vector>
  9. using namespace cv;
  10. using namespace std;
  // Side length, in pixels, of each square digit cell in the dataset image.
  const int SZ = 20; // size of each digit is SZ x SZ
  // Number of digit classes; labels are 0 .. CLASS_N - 1.
  const int CLASS_N = 10;
  // Dataset image name, resolved through samples::findFile() in load_digits().
  // The pointer itself is const so it cannot be reassigned (and, being a const
  // namespace-scope object, it has internal linkage like SZ and CLASS_N).
  const char* const DIGITS_FN = "digits.png";
  // Prints a banner describing what the sample does and which preprocessing it
  // applies, followed by "Usage:" and the program name taken from argv[0].
  static void help(char** argv)
  {
      static const char* const description =
          "\n"
          "SVM and KNearest digit recognition.\n"
          "\n"
          "Sample loads a dataset of handwritten digits from 'digits.png'.\n"
          "Then it trains a SVM and KNearest classifiers on it and evaluates\n"
          "their accuracy.\n"
          "\n"
          "Following preprocessing is applied to the dataset:\n"
          " - Moment-based image deskew (see deskew())\n"
          " - Digit images are split into 4 10x10 cells and 16-bin\n"
          " histogram of oriented gradients is computed for each\n"
          " cell\n"
          " - Transform histograms to space with Hellinger metric (see [1] (RootSIFT))\n"
          "\n"
          "\n"
          "[1] R. Arandjelovic, A. Zisserman\n"
          " \"Three things everyone should know to improve object retrieval\"\n"
          " http://www.robots.ox.ac.uk/~vgg/publications/2012/Arandjelovic12/arandjelovic12.pdf\n"
          "\n"
          "Usage:\n";
      cout << description;
      cout << argv[0] << endl;
  }
  39. static void split2d(const Mat& image, const Size cell_size, vector<Mat>& cells)
  40. {
  41. int height = image.rows;
  42. int width = image.cols;
  43. int sx = cell_size.width;
  44. int sy = cell_size.height;
  45. cells.clear();
  46. for (int i = 0; i < height; i += sy)
  47. {
  48. for (int j = 0; j < width; j += sx)
  49. {
  50. cells.push_back(image(Rect(j, i, sx, sy)));
  51. }
  52. }
  53. }
  54. static void load_digits(const char* fn, vector<Mat>& digits, vector<int>& labels)
  55. {
  56. digits.clear();
  57. labels.clear();
  58. String filename = samples::findFile(fn);
  59. cout << "Loading " << filename << " ..." << endl;
  60. Mat digits_img = imread(filename, IMREAD_GRAYSCALE);
  61. split2d(digits_img, Size(SZ, SZ), digits);
  62. for (int i = 0; i < CLASS_N; i++)
  63. {
  64. for (size_t j = 0; j < digits.size() / CLASS_N; j++)
  65. {
  66. labels.push_back(i);
  67. }
  68. }
  69. }
  70. static void deskew(const Mat& img, Mat& deskewed_img)
  71. {
  72. Moments m = moments(img);
  73. if (abs(m.mu02) < 0.01)
  74. {
  75. deskewed_img = img.clone();
  76. return;
  77. }
  78. float skew = (float)(m.mu11 / m.mu02);
  79. float M_vals[2][3] = {{1, skew, -0.5f * SZ * skew}, {0, 1, 0}};
  80. Mat M(Size(3, 2), CV_32F);
  81. for (int i = 0; i < M.rows; i++)
  82. {
  83. for (int j = 0; j < M.cols; j++)
  84. {
  85. M.at<float>(i, j) = M_vals[i][j];
  86. }
  87. }
  88. warpAffine(img, deskewed_img, M, Size(SZ, SZ), WARP_INVERSE_MAP | INTER_LINEAR);
  89. }
  90. static void mosaic(const int width, const vector<Mat>& images, Mat& grid)
  91. {
  92. int mat_width = SZ * width;
  93. int mat_height = SZ * (int)ceil((double)images.size() / width);
  94. if (!images.empty())
  95. {
  96. grid = Mat(Size(mat_width, mat_height), images[0].type());
  97. for (size_t i = 0; i < images.size(); i++)
  98. {
  99. Mat location_on_grid = grid(Rect(SZ * ((int)i % width), SZ * ((int)i / width), SZ, SZ));
  100. images[i].copyTo(location_on_grid);
  101. }
  102. }
  103. }
  104. static void evaluate_model(const vector<float>& predictions, const vector<Mat>& digits, const vector<int>& labels, Mat& mos)
  105. {
  106. double err = 0;
  107. for (size_t i = 0; i < predictions.size(); i++)
  108. {
  109. if ((int)predictions[i] != labels[i])
  110. {
  111. err++;
  112. }
  113. }
  114. err /= predictions.size();
  115. cout << cv::format("error: %.2f %%", err * 100) << endl;
  116. int confusion[10][10] = {};
  117. for (size_t i = 0; i < labels.size(); i++)
  118. {
  119. confusion[labels[i]][(int)predictions[i]]++;
  120. }
  121. cout << "confusion matrix:" << endl;
  122. for (int i = 0; i < 10; i++)
  123. {
  124. for (int j = 0; j < 10; j++)
  125. {
  126. cout << cv::format("%2d ", confusion[i][j]);
  127. }
  128. cout << endl;
  129. }
  130. cout << endl;
  131. vector<Mat> vis;
  132. for (size_t i = 0; i < digits.size(); i++)
  133. {
  134. Mat img;
  135. cvtColor(digits[i], img, COLOR_GRAY2BGR);
  136. if ((int)predictions[i] != labels[i])
  137. {
  138. for (int j = 0; j < img.rows; j++)
  139. {
  140. for (int k = 0; k < img.cols; k++)
  141. {
  142. img.at<Vec3b>(j, k)[0] = 0;
  143. img.at<Vec3b>(j, k)[1] = 0;
  144. }
  145. }
  146. }
  147. vis.push_back(img);
  148. }
  149. mosaic(25, vis, mos);
  150. }
  151. static void bincount(const Mat& x, const Mat& weights, const int min_length, vector<double>& bins)
  152. {
  153. double max_x_val = 0;
  154. minMaxLoc(x, NULL, &max_x_val);
  155. bins = vector<double>(max((int)max_x_val, min_length));
  156. for (int i = 0; i < x.rows; i++)
  157. {
  158. for (int j = 0; j < x.cols; j++)
  159. {
  160. bins[x.at<int>(i, j)] += weights.at<float>(i, j);
  161. }
  162. }
  163. }
  164. static void preprocess_hog(const vector<Mat>& digits, Mat& hog)
  165. {
  166. int bin_n = 16;
  167. int half_cell = SZ / 2;
  168. double eps = 1e-7;
  169. hog = Mat(Size(4 * bin_n, (int)digits.size()), CV_32F);
  170. for (size_t img_index = 0; img_index < digits.size(); img_index++)
  171. {
  172. Mat gx;
  173. Sobel(digits[img_index], gx, CV_32F, 1, 0);
  174. Mat gy;
  175. Sobel(digits[img_index], gy, CV_32F, 0, 1);
  176. Mat mag;
  177. Mat ang;
  178. cartToPolar(gx, gy, mag, ang);
  179. Mat bin(ang.size(), CV_32S);
  180. for (int i = 0; i < ang.rows; i++)
  181. {
  182. for (int j = 0; j < ang.cols; j++)
  183. {
  184. bin.at<int>(i, j) = (int)(bin_n * ang.at<float>(i, j) / (2 * CV_PI));
  185. }
  186. }
  187. Mat bin_cells[] = {
  188. bin(Rect(0, 0, half_cell, half_cell)),
  189. bin(Rect(half_cell, 0, half_cell, half_cell)),
  190. bin(Rect(0, half_cell, half_cell, half_cell)),
  191. bin(Rect(half_cell, half_cell, half_cell, half_cell))
  192. };
  193. Mat mag_cells[] = {
  194. mag(Rect(0, 0, half_cell, half_cell)),
  195. mag(Rect(half_cell, 0, half_cell, half_cell)),
  196. mag(Rect(0, half_cell, half_cell, half_cell)),
  197. mag(Rect(half_cell, half_cell, half_cell, half_cell))
  198. };
  199. vector<double> hist;
  200. hist.reserve(4 * bin_n);
  201. for (int i = 0; i < 4; i++)
  202. {
  203. vector<double> partial_hist;
  204. bincount(bin_cells[i], mag_cells[i], bin_n, partial_hist);
  205. hist.insert(hist.end(), partial_hist.begin(), partial_hist.end());
  206. }
  207. // transform to Hellinger kernel
  208. double sum = 0;
  209. for (size_t i = 0; i < hist.size(); i++)
  210. {
  211. sum += hist[i];
  212. }
  213. for (size_t i = 0; i < hist.size(); i++)
  214. {
  215. hist[i] /= sum + eps;
  216. hist[i] = sqrt(hist[i]);
  217. }
  218. double hist_norm = norm(hist);
  219. for (size_t i = 0; i < hist.size(); i++)
  220. {
  221. hog.at<float>((int)img_index, (int)i) = (float)(hist[i] / (hist_norm + eps));
  222. }
  223. }
  224. }
  225. static void shuffle(vector<Mat>& digits, vector<int>& labels)
  226. {
  227. vector<int> shuffled_indexes(digits.size());
  228. for (size_t i = 0; i < digits.size(); i++)
  229. {
  230. shuffled_indexes[i] = (int)i;
  231. }
  232. randShuffle(shuffled_indexes);
  233. vector<Mat> shuffled_digits(digits.size());
  234. vector<int> shuffled_labels(labels.size());
  235. for (size_t i = 0; i < shuffled_indexes.size(); i++)
  236. {
  237. shuffled_digits[shuffled_indexes[i]] = digits[i];
  238. shuffled_labels[shuffled_indexes[i]] = labels[i];
  239. }
  240. digits = shuffled_digits;
  241. labels = shuffled_labels;
  242. }
  243. int main(int /* argc */, char* argv[])
  244. {
  245. help(argv);
  246. vector<Mat> digits;
  247. vector<int> labels;
  248. load_digits(DIGITS_FN, digits, labels);
  249. cout << "preprocessing..." << endl;
  250. // shuffle digits
  251. shuffle(digits, labels);
  252. vector<Mat> digits2;
  253. for (size_t i = 0; i < digits.size(); i++)
  254. {
  255. Mat deskewed_digit;
  256. deskew(digits[i], deskewed_digit);
  257. digits2.push_back(deskewed_digit);
  258. }
  259. Mat samples;
  260. preprocess_hog(digits2, samples);
  261. int train_n = (int)(0.9 * samples.rows);
  262. Mat test_set;
  263. vector<Mat> digits_test(digits2.begin() + train_n, digits2.end());
  264. mosaic(25, digits_test, test_set);
  265. imshow("test set", test_set);
  266. Mat samples_train = samples(Rect(0, 0, samples.cols, train_n));
  267. Mat samples_test = samples(Rect(0, train_n, samples.cols, samples.rows - train_n));
  268. vector<int> labels_train(labels.begin(), labels.begin() + train_n);
  269. vector<int> labels_test(labels.begin() + train_n, labels.end());
  270. Ptr<ml::KNearest> k_nearest;
  271. Ptr<ml::SVM> svm;
  272. vector<float> predictions;
  273. Mat vis;
  274. cout << "training KNearest..." << endl;
  275. k_nearest = ml::KNearest::create();
  276. k_nearest->train(samples_train, ml::ROW_SAMPLE, labels_train);
  277. // predict digits with KNearest
  278. k_nearest->findNearest(samples_test, 4, predictions);
  279. evaluate_model(predictions, digits_test, labels_test, vis);
  280. imshow("KNearest test", vis);
  281. k_nearest.release();
  282. cout << "training SVM..." << endl;
  283. svm = ml::SVM::create();
  284. svm->setGamma(5.383);
  285. svm->setC(2.67);
  286. svm->setKernel(ml::SVM::RBF);
  287. svm->setType(ml::SVM::C_SVC);
  288. svm->train(samples_train, ml::ROW_SAMPLE, labels_train);
  289. // predict digits with SVM
  290. svm->predict(samples_test, predictions);
  291. evaluate_model(predictions, digits_test, labels_test, vis);
  292. imshow("SVM test", vis);
  293. cout << "Saving SVM as \"digits_svm.yml\"..." << endl;
  294. svm->save("digits_svm.yml");
  295. svm.release();
  296. waitKey();
  297. return 0;
  298. }