letter_recog.cpp
#include "opencv2/core.hpp"
#include "opencv2/ml.hpp"

#include <cstdio>
#include <cstring>   // strchr, memcpy
#include <cfloat>    // FLT_EPSILON, DBL_MAX
#include <vector>
#include <iostream>

using namespace std;
using namespace cv;
using namespace cv::ml;
static void help(char** argv)
{
    printf("\nThe sample demonstrates how to train a Random Trees classifier\n"
           "(or a Boosting classifier, or MLP, or KNearest, or Normal Bayes, or Support Vector Machines - see main()) using the provided dataset.\n"
           "\n"
           "We use the sample database letter-recognition.data\n"
           "from the UCI Repository, here is the link:\n"
           "\n"
           "Newman, D.J. & Hettich, S. & Blake, C.L. & Merz, C.J. (1998).\n"
           "UCI Repository of machine learning databases\n"
           "[http://www.ics.uci.edu/~mlearn/MLRepository.html].\n"
           "Irvine, CA: University of California, Department of Information and Computer Science.\n"
           "\n"
           "The dataset consists of 20000 feature vectors along with the\n"
           "responses - capital latin letters A..Z.\n"
           "The first 16000 (10000 for boosting) samples are used for training\n"
           "and the remaining 4000 (10000 for boosting) - to test the classifier.\n"
           "======================================================\n");
    printf("\nThis is the letter recognition sample.\n"
           "The usage: %s [-data=<path to letter-recognition.data>] \\\n"
           "  [-save=<output XML file for the classifier>] \\\n"
           "  [-load=<XML file with the pre-trained classifier>] \\\n"
           "  [-boost|-mlp|-knearest|-nbayes|-svm] # to use boost/mlp/knearest/nbayes/SVM classifier instead of default Random Trees\n", argv[0] );
}
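
// Each line of letter-recognition.data holds one sample: the response letter
// followed by 16 comma-separated integer features, e.g.
//   T,2,8,3,5,1,8,13,0,6,6,10,8,0,8,0,8
// read_num_class_data() below relies on this layout: buf[0] is the label and
// the features start at buf+2.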
// This function reads data and responses from the file <filename>
static bool
read_num_class_data( const string& filename, int var_count,
                     Mat* _data, Mat* _responses )
{
    const int M = 1024;
    char buf[M+2];

    Mat el_ptr(1, var_count, CV_32F);
    int i;
    vector<int> responses;

    _data->release();
    _responses->release();

    FILE* f = fopen( filename.c_str(), "rt" );
    if( !f )
    {
        cout << "Could not read the database " << filename << endl;
        return false;
    }

    for(;;)
    {
        char* ptr;
        if( !fgets( buf, M, f ) || !strchr( buf, ',' ) )
            break;
        responses.push_back((int)buf[0]);
        ptr = buf+2;
        for( i = 0; i < var_count; i++ )
        {
            int n = 0;
            sscanf( ptr, "%f%n", &el_ptr.at<float>(i), &n );
            ptr += n + 1;
        }
        if( i < var_count )
            break;
        _data->push_back(el_ptr);
    }
    fclose(f);
    Mat(responses).copyTo(*_responses);

    cout << "The database " << filename << " is loaded.\n";
    return true;
}
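
// On success, *_data is an N x var_count CV_32F matrix of features and
// *_responses is an N x 1 CV_32S column of ASCII label codes ('A'..'Z' here).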
template<typename T>
static Ptr<T> load_classifier(const string& filename_to_load)
{
    // load classifier from the specified file
    Ptr<T> model = StatModel::load<T>( filename_to_load );
    if( model.empty() )
        cout << "Could not read the classifier " << filename_to_load << endl;
    else
        cout << "The classifier " << filename_to_load << " is loaded.\n";
    return model;
}
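
// prepare_train_data() marks the first ntrain_samples rows as the training
// subset via a 0/1 mask (sample_idx) and declares all predictors as ordered,
// with one extra var_type entry marking the response as categorical.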
static Ptr<TrainData>
prepare_train_data(const Mat& data, const Mat& responses, int ntrain_samples)
{
    Mat sample_idx = Mat::zeros( 1, data.rows, CV_8U );
    Mat train_samples = sample_idx.colRange(0, ntrain_samples);
    train_samples.setTo(Scalar::all(1));

    int nvars = data.cols;
    Mat var_type( nvars + 1, 1, CV_8U );
    var_type.setTo(Scalar::all(VAR_ORDERED));
    var_type.at<uchar>(nvars) = VAR_CATEGORICAL;

    return TrainData::create(data, ROW_SAMPLE, responses,
                             noArray(), sample_idx, noArray(), var_type);
}
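
// TC() builds a TermCriteria that always caps the iteration count and adds
// the EPS criterion only when eps > 0. For example:
//   TC(100, 0.01)  -> stop after 100 iterations or once the change <= 0.01
//   TC(300, 0)     -> stop on the iteration count alone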
inline TermCriteria TC(int iters, double eps)
{
    return TermCriteria(TermCriteria::MAX_ITER + (eps > 0 ? TermCriteria::EPS : 0), iters, eps);
}
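
// rdelta compensates for the label encoding: the MLP predicts class indices
// 0..25, so 'A' is passed as rdelta to map predictions back to the ASCII
// codes stored in <responses>; the other models predict the codes directly
// and pass 0.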
static void test_and_save_classifier(const Ptr<StatModel>& model,
                                     const Mat& data, const Mat& responses,
                                     int ntrain_samples, int rdelta,
                                     const string& filename_to_save)
{
    int i, nsamples_all = data.rows;
    double train_hr = 0, test_hr = 0;

    // compute prediction error on train and test data
    for( i = 0; i < nsamples_all; i++ )
    {
        Mat sample = data.row(i);

        float r = model->predict( sample );
        r = std::abs(r + rdelta - responses.at<int>(i)) <= FLT_EPSILON ? 1.f : 0.f;

        if( i < ntrain_samples )
            train_hr += r;
        else
            test_hr += r;
    }

    test_hr /= nsamples_all - ntrain_samples;
    train_hr = ntrain_samples > 0 ? train_hr/ntrain_samples : 1.;

    printf( "Recognition rate: train = %.1f%%, test = %.1f%%\n",
            train_hr*100., test_hr*100. );

    if( !filename_to_save.empty() )
    {
        model->save( filename_to_save );
    }
}
static bool
build_rtrees_classifier( const string& data_filename,
                         const string& filename_to_save,
                         const string& filename_to_load )
{
    Mat data;
    Mat responses;
    bool ok = read_num_class_data( data_filename, 16, &data, &responses );
    if( !ok )
        return ok;

    Ptr<RTrees> model;

    int nsamples_all = data.rows;
    int ntrain_samples = (int)(nsamples_all*0.8);

    // Create or load Random Trees classifier
    if( !filename_to_load.empty() )
    {
        model = load_classifier<RTrees>(filename_to_load);
        if( model.empty() )
            return false;
        ntrain_samples = 0;
    }
    else
    {
        // create classifier by using <data> and <responses>
        cout << "Training the classifier ...\n";
        // The setters below correspond to the legacy parameter list:
        // Params( int maxDepth, int minSampleCount,
        //         double regressionAccuracy, bool useSurrogates,
        //         int maxCategories, const Mat& priors,
        //         bool calcVarImportance, int nactiveVars,
        //         TermCriteria termCrit );
        Ptr<TrainData> tdata = prepare_train_data(data, responses, ntrain_samples);
        model = RTrees::create();
        model->setMaxDepth(10);
        model->setMinSampleCount(10);
        model->setRegressionAccuracy(0);
        model->setUseSurrogates(false);
        model->setMaxCategories(15);
        model->setPriors(Mat());
        model->setCalculateVarImportance(true);
        model->setActiveVarCount(4);
        model->setTermCriteria(TC(100,0.01f));
        model->train(tdata);
        cout << endl;
    }

    test_and_save_classifier(model, data, responses, ntrain_samples, 0, filename_to_save);
    cout << "Number of trees: " << model->getRoots().size() << endl;

    // Print variable importance
    Mat var_importance = model->getVarImportance();
    if( !var_importance.empty() )
    {
        double rt_imp_sum = sum( var_importance )[0];
        printf("var#\timportance (in %%):\n");
        int i, n = (int)var_importance.total();
        for( i = 0; i < n; i++ )
            printf( "%-2d\t%-4.1f\n", i, 100.f*var_importance.at<float>(i)/rt_imp_sum);
    }
    return true;
}
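
// Note: boosting uses a 50/50 train/test split (10000/10000) rather than the
// 80/20 split used elsewhere, presumably because the 26-fold "unrolling"
// below multiplies the effective training set size.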
static bool
build_boost_classifier( const string& data_filename,
                        const string& filename_to_save,
                        const string& filename_to_load )
{
    const int class_count = 26;
    Mat data;
    Mat responses;
    Mat weak_responses;

    bool ok = read_num_class_data( data_filename, 16, &data, &responses );
    if( !ok )
        return ok;

    int i, j, k;
    Ptr<Boost> model;

    int nsamples_all = data.rows;
    int ntrain_samples = (int)(nsamples_all*0.5);
    int var_count = data.cols;

    // Create or load Boosted Tree classifier
    if( !filename_to_load.empty() )
    {
        model = load_classifier<Boost>(filename_to_load);
        if( model.empty() )
            return false;
        ntrain_samples = 0;
    }
    else
    {
        // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
        //
        // As the boosted tree classifier in the ML module can currently
        // only be trained on 2-class problems, we transform the training
        // database by "unrolling" each training sample as many times as
        // the number of classes (26) that we have.
        //
        // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
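        //
        // For example, a sample x with label 'C' (class index 2) becomes
        // 26 rows: (x, 0) -> 0, (x, 1) -> 0, (x, 2) -> 1, ..., (x, 25) -> 0.
        // The candidate class index is appended as one extra categorical
        // feature and the binary response tells whether that candidate
        // matches the true label.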
        Mat new_data( ntrain_samples*class_count, var_count + 1, CV_32F );
        Mat new_responses( ntrain_samples*class_count, 1, CV_32S );

        // 1. unroll the database type mask
        printf( "Unrolling the database...\n");
        for( i = 0; i < ntrain_samples; i++ )
        {
            const float* data_row = data.ptr<float>(i);
            for( j = 0; j < class_count; j++ )
            {
                float* new_data_row = (float*)new_data.ptr<float>(i*class_count+j);
                memcpy(new_data_row, data_row, var_count*sizeof(data_row[0]));
                new_data_row[var_count] = (float)j;
                new_responses.at<int>(i*class_count + j) = responses.at<int>(i) == j+'A';
            }
        }

        Mat var_type( 1, var_count + 2, CV_8U );
        var_type.setTo(Scalar::all(VAR_ORDERED));
        var_type.at<uchar>(var_count) = var_type.at<uchar>(var_count+1) = VAR_CATEGORICAL;

        Ptr<TrainData> tdata = TrainData::create(new_data, ROW_SAMPLE, new_responses,
                                                 noArray(), noArray(), noArray(), var_type);
        vector<double> priors(2);
        priors[0] = 1;
        priors[1] = 26;

        cout << "Training the classifier (may take a few minutes)...\n";
        model = Boost::create();
        model->setBoostType(Boost::GENTLE);
        model->setWeakCount(100);
        model->setWeightTrimRate(0.95);
        model->setMaxDepth(5);
        model->setUseSurrogates(false);
        model->setPriors(Mat(priors));
        model->train(tdata);
        cout << endl;
    }
    Mat temp_sample( 1, var_count + 1, CV_32F );
    float* tptr = temp_sample.ptr<float>();

    // compute prediction error on train and test data
    double train_hr = 0, test_hr = 0;
    for( i = 0; i < nsamples_all; i++ )
    {
        int best_class = 0;
        double max_sum = -DBL_MAX;
        const float* ptr = data.ptr<float>(i);
        for( k = 0; k < var_count; k++ )
            tptr[k] = ptr[k];

        for( j = 0; j < class_count; j++ )
        {
            tptr[var_count] = (float)j;
            float s = model->predict( temp_sample, noArray(), StatModel::RAW_OUTPUT );
            if( max_sum < s )
            {
                max_sum = s;
                best_class = j + 'A';
            }
        }

        double r = std::abs(best_class - responses.at<int>(i)) < FLT_EPSILON ? 1 : 0;
        if( i < ntrain_samples )
            train_hr += r;
        else
            test_hr += r;
    }

    test_hr /= nsamples_all-ntrain_samples;
    train_hr = ntrain_samples > 0 ? train_hr/ntrain_samples : 1.;
    printf( "Recognition rate: train = %.1f%%, test = %.1f%%\n",
            train_hr*100., test_hr*100. );

    cout << "Number of trees: " << model->getRoots().size() << endl;

    // Save classifier to file if needed
    if( !filename_to_save.empty() )
        model->save( filename_to_save );

    return true;
}
static bool
build_mlp_classifier( const string& data_filename,
                      const string& filename_to_save,
                      const string& filename_to_load )
{
    const int class_count = 26;
    Mat data;
    Mat responses;

    bool ok = read_num_class_data( data_filename, 16, &data, &responses );
    if( !ok )
        return ok;

    Ptr<ANN_MLP> model;

    int nsamples_all = data.rows;
    int ntrain_samples = (int)(nsamples_all*0.8);

    // Create or load MLP classifier
    if( !filename_to_load.empty() )
    {
        model = load_classifier<ANN_MLP>(filename_to_load);
        if( model.empty() )
            return false;
        ntrain_samples = 0;
    }
    else
    {
        // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
        //
        // MLP does not support categorical variables explicitly.
        // So, instead of the output class label, we will use
        // a binary vector of <class_count> components for training and,
        // therefore, MLP will give us a vector of "probabilities" at the
        // prediction stage.
        //
        // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
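        //
        // For example, the label 'C' (class index 2) is encoded as the
        // 26-component row [0, 0, 1, 0, ..., 0]; at prediction time,
        // ANN_MLP::predict() returns the index of the strongest output,
        // which test_and_save_classifier() shifts back by 'A' (rdelta).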
        Mat train_data = data.rowRange(0, ntrain_samples);
        Mat train_responses = Mat::zeros( ntrain_samples, class_count, CV_32F );

        // 1. unroll the responses
        cout << "Unrolling the responses...\n";
        for( int i = 0; i < ntrain_samples; i++ )
        {
            int cls_label = responses.at<int>(i) - 'A';
            train_responses.at<float>(i, cls_label) = 1.f;
        }

        // 2. train classifier
        int layer_sz[] = { data.cols, 100, 100, class_count };
        int nlayers = (int)(sizeof(layer_sz)/sizeof(layer_sz[0]));
        Mat layer_sizes( 1, nlayers, CV_32S, layer_sz );

#if 1
        int method = ANN_MLP::BACKPROP;
        double method_param = 0.001;
        int max_iter = 300;
#else
        int method = ANN_MLP::RPROP;
        double method_param = 0.1;
        int max_iter = 1000;
#endif

        Ptr<TrainData> tdata = TrainData::create(train_data, ROW_SAMPLE, train_responses);

        cout << "Training the classifier (may take a few minutes)...\n";
        model = ANN_MLP::create();
        model->setLayerSizes(layer_sizes);
        model->setActivationFunction(ANN_MLP::SIGMOID_SYM, 0, 0);
        model->setTermCriteria(TC(max_iter,0));
        model->setTrainMethod(method, method_param);
        model->train(tdata);
        cout << endl;
    }

    test_and_save_classifier(model, data, responses, ntrain_samples, 'A', filename_to_save);
    return true;
}
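
// KNearest does not build an explicit model: it stores the training samples
// and classifies a query by a majority vote among its K nearest neighbors
// (K is passed in from main(), 10 in this sample).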
static bool
build_knearest_classifier( const string& data_filename, int K )
{
    Mat data;
    Mat responses;
    bool ok = read_num_class_data( data_filename, 16, &data, &responses );
    if( !ok )
        return ok;

    int nsamples_all = data.rows;
    int ntrain_samples = (int)(nsamples_all*0.8);

    // create classifier by using <data> and <responses>
    cout << "Training the classifier ...\n";
    Ptr<TrainData> tdata = prepare_train_data(data, responses, ntrain_samples);
    Ptr<KNearest> model = KNearest::create();
    model->setDefaultK(K);
    model->setIsClassifier(true);
    model->train(tdata);
    cout << endl;

    test_and_save_classifier(model, data, responses, ntrain_samples, 0, string());
    return true;
}
static bool
build_nbayes_classifier( const string& data_filename )
{
    Mat data;
    Mat responses;
    bool ok = read_num_class_data( data_filename, 16, &data, &responses );
    if( !ok )
        return ok;

    Ptr<NormalBayesClassifier> model;

    int nsamples_all = data.rows;
    int ntrain_samples = (int)(nsamples_all*0.8);

    // create classifier by using <data> and <responses>
    cout << "Training the classifier ...\n";
    Ptr<TrainData> tdata = prepare_train_data(data, responses, ntrain_samples);
    model = NormalBayesClassifier::create();
    model->train(tdata);
    cout << endl;

    test_and_save_classifier(model, data, responses, ntrain_samples, 0, string());
    return true;
}
static bool
build_svm_classifier( const string& data_filename,
                      const string& filename_to_save,
                      const string& filename_to_load )
{
    Mat data;
    Mat responses;
    bool ok = read_num_class_data( data_filename, 16, &data, &responses );
    if( !ok )
        return ok;

    Ptr<SVM> model;

    int nsamples_all = data.rows;
    int ntrain_samples = (int)(nsamples_all*0.8);

    // Create or load SVM classifier
    if( !filename_to_load.empty() )
    {
        model = load_classifier<SVM>(filename_to_load);
        if( model.empty() )
            return false;
        ntrain_samples = 0;
    }
    else
    {
        // create classifier by using <data> and <responses>
        cout << "Training the classifier ...\n";
        Ptr<TrainData> tdata = prepare_train_data(data, responses, ntrain_samples);
        model = SVM::create();
        model->setType(SVM::C_SVC);
        model->setKernel(SVM::LINEAR);
        model->setC(1);
        model->train(tdata);
        cout << endl;
    }

    test_and_save_classifier(model, data, responses, ntrain_samples, 0, filename_to_save);
    return true;
}
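
// Example invocations (assuming letter-recognition.data is in the current
// directory or findable on the OpenCV samples data path):
//   ./letter_recog                          # Random Trees (the default)
//   ./letter_recog -save=rtrees.xml         # train and save the model
//   ./letter_recog -load=rtrees.xml         # evaluate a previously saved model
//   ./letter_recog -mlp                     # train the MLP instead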
int main( int argc, char *argv[] )
{
    string filename_to_save = "";
    string filename_to_load = "";
    string data_filename;
    int method = 0;

    cv::CommandLineParser parser(argc, argv, "{data|letter-recognition.data|}{save||}{load||}{boost||}"
                                             "{mlp||}{knn knearest||}{nbayes||}{svm||}");
    data_filename = samples::findFile(parser.get<string>("data"));
    if (parser.has("save"))
        filename_to_save = parser.get<string>("save");
    if (parser.has("load"))
        filename_to_load = samples::findFile(parser.get<string>("load"));

    if (parser.has("boost"))
        method = 1;
    else if (parser.has("mlp"))
        method = 2;
    else if (parser.has("knearest"))
        method = 3;
    else if (parser.has("nbayes"))
        method = 4;
    else if (parser.has("svm"))
        method = 5;

    help(argv);

    if( (method == 0 ?
        build_rtrees_classifier( data_filename, filename_to_save, filename_to_load ) :
        method == 1 ?
        build_boost_classifier( data_filename, filename_to_save, filename_to_load ) :
        method == 2 ?
        build_mlp_classifier( data_filename, filename_to_save, filename_to_load ) :
        method == 3 ?
        build_knearest_classifier( data_filename, 10 ) :
        method == 4 ?
        build_nbayes_classifier( data_filename ) :
        method == 5 ?
        build_svm_classifier( data_filename, filename_to_save, filename_to_load ):
        -1) < 0)
        return 0;
    return 0;
}