facerec_lbph.cpp 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. /*
  2. * Copyright (c) 2011. Philipp Wagner <bytefish[at]gmx[dot]de>.
  3. * Released to public domain under terms of the BSD Simplified license.
  4. *
  5. * Redistribution and use in source and binary forms, with or without
  6. * modification, are permitted provided that the following conditions are met:
  7. * * Redistributions of source code must retain the above copyright
  8. * notice, this list of conditions and the following disclaimer.
  9. * * Redistributions in binary form must reproduce the above copyright
  10. * notice, this list of conditions and the following disclaimer in the
  11. * documentation and/or other materials provided with the distribution.
  12. * * Neither the name of the organization nor the names of its contributors
  13. * may be used to endorse or promote products derived from this software
  14. * without specific prior written permission.
  15. *
  16. * See <http://www.opensource.org/licenses/bsd-license>
  17. */
  18. #include "opencv2/core.hpp"
  19. #include "opencv2/face.hpp"
  20. #include "opencv2/highgui.hpp"
  21. #include <iostream>
  22. #include <fstream>
  23. #include <sstream>
  24. using namespace cv;
  25. using namespace cv::face;
  26. using namespace std;
  27. static void read_csv(const string& filename, vector<Mat>& images, vector<int>& labels, char separator = ';') {
  28. std::ifstream file(filename.c_str(), ifstream::in);
  29. if (!file) {
  30. string error_message = "No valid input file was given, please check the given filename.";
  31. CV_Error(Error::StsBadArg, error_message);
  32. }
  33. string line, path, classlabel;
  34. while (getline(file, line)) {
  35. stringstream liness(line);
  36. getline(liness, path, separator);
  37. getline(liness, classlabel);
  38. if(!path.empty() && !classlabel.empty()) {
  39. images.push_back(imread(path, 0));
  40. labels.push_back(atoi(classlabel.c_str()));
  41. }
  42. }
  43. }
  44. int main(int argc, const char *argv[]) {
  45. // Check for valid command line arguments, print usage
  46. // if no arguments were given.
  47. if (argc != 2) {
  48. cout << "usage: " << argv[0] << " <csv.ext>" << endl;
  49. exit(1);
  50. }
  51. // Get the path to your CSV.
  52. string fn_csv = string(argv[1]);
  53. // These vectors hold the images and corresponding labels.
  54. vector<Mat> images;
  55. vector<int> labels;
  56. // Read in the data. This can fail if no valid
  57. // input filename is given.
  58. try {
  59. read_csv(fn_csv, images, labels);
  60. } catch (const cv::Exception& e) {
  61. cerr << "Error opening file \"" << fn_csv << "\". Reason: " << e.msg << endl;
  62. // nothing more we can do
  63. exit(1);
  64. }
  65. // Quit if there are not enough images for this demo.
  66. if(images.size() <= 1) {
  67. string error_message = "This demo needs at least 2 images to work. Please add more images to your data set!";
  68. CV_Error(Error::StsError, error_message);
  69. }
  70. // The following lines simply get the last images from
  71. // your dataset and remove it from the vector. This is
  72. // done, so that the training data (which we learn the
  73. // cv::LBPHFaceRecognizer on) and the test data we test
  74. // the model with, do not overlap.
  75. Mat testSample = images[images.size() - 1];
  76. int testLabel = labels[labels.size() - 1];
  77. images.pop_back();
  78. labels.pop_back();
  79. // The following lines create an LBPH model for
  80. // face recognition and train it with the images and
  81. // labels read from the given CSV file.
  82. //
  83. // The LBPHFaceRecognizer uses Extended Local Binary Patterns
  84. // (it's probably configurable with other operators at a later
  85. // point), and has the following default values
  86. //
  87. // radius = 1
  88. // neighbors = 8
  89. // grid_x = 8
  90. // grid_y = 8
  91. //
  92. // So if you want a LBPH FaceRecognizer using a radius of
  93. // 2 and 16 neighbors, call the factory method with:
  94. //
  95. // cv::face::LBPHFaceRecognizer::create(2, 16);
  96. //
  97. // And if you want a threshold (e.g. 123.0) call it with its default values:
  98. //
  99. // cv::face::LBPHFaceRecognizer::create(1,8,8,8,123.0)
  100. //
  101. Ptr<LBPHFaceRecognizer> model = LBPHFaceRecognizer::create();
  102. model->train(images, labels);
  103. // The following line predicts the label of a given
  104. // test image:
  105. int predictedLabel = model->predict(testSample);
  106. //
  107. // To get the confidence of a prediction call the model with:
  108. //
  109. // int predictedLabel = -1;
  110. // double confidence = 0.0;
  111. // model->predict(testSample, predictedLabel, confidence);
  112. //
  113. string result_message = format("Predicted class = %d / Actual class = %d.", predictedLabel, testLabel);
  114. cout << result_message << endl;
  115. // First we'll use it to set the threshold of the LBPHFaceRecognizer
  116. // to 0.0 without retraining the model. This can be useful if
  117. // you are evaluating the model:
  118. //
  119. model->setThreshold(0.0);
  120. // Now the threshold of this model is set to 0.0. A prediction
  121. // now returns -1, as it's impossible to have a distance below
  122. // it
  123. predictedLabel = model->predict(testSample);
  124. cout << "Predicted class = " << predictedLabel << endl;
  125. // Show some informations about the model, as there's no cool
  126. // Model data to display as in Eigenfaces/Fisherfaces.
  127. // Due to efficiency reasons the LBP images are not stored
  128. // within the model:
  129. cout << "Model Information:" << endl;
  130. string model_info = format("\tLBPH(radius=%i, neighbors=%i, grid_x=%i, grid_y=%i, threshold=%.2f)",
  131. model->getRadius(),
  132. model->getNeighbors(),
  133. model->getGridX(),
  134. model->getGridY(),
  135. model->getThreshold());
  136. cout << model_info << endl;
  137. // We could get the histograms for example:
  138. vector<Mat> histograms = model->getHistograms();
  139. // But should I really visualize it? Probably the length is interesting:
  140. cout << "Size of the histograms: " << histograms[0].total() << endl;
  141. return 0;
  142. }