// sample_train_landmark_detector.cpp (4.8 KB)
  #include <algorithm>
  #include <iostream>
  #include <string>
  #include <vector>

  #include "opencv2/face.hpp"
  #include "opencv2/highgui.hpp"
  #include "opencv2/imgcodecs.hpp"
  #include "opencv2/imgproc.hpp"
  #include "opencv2/objdetect.hpp"

  using namespace std;
  using namespace cv;
  using namespace cv::face;
  12. static bool myDetector(InputArray image, OutputArray faces, CascadeClassifier *face_cascade)
  13. {
  14. Mat gray;
  15. if (image.channels() > 1)
  16. cvtColor(image, gray, COLOR_BGR2GRAY);
  17. else
  18. gray = image.getMat().clone();
  19. equalizeHist(gray, gray);
  20. std::vector<Rect> faces_;
  21. face_cascade->detectMultiScale(gray, faces_, 1.4, 2, CASCADE_SCALE_IMAGE, Size(30, 30));
  22. Mat(faces_).copyTo(faces);
  23. return true;
  24. }
  25. int main(int argc,char** argv){
  26. //Give the path to the directory containing all the files containing data
  27. CommandLineParser parser(argc, argv,
  28. "{ help h usage ? | | give the following arguments in following format }"
  29. "{ annotations a |. | (required) path to annotations txt file [example - /data/annotations.txt] }"
  30. "{ config c | | (required) path to configuration xml file containing parameters for training.[ example - /data/config.xml] }"
  31. "{ model m | | (required) path to configuration xml file containing parameters for training.[ example - /data/model.dat] }"
  32. "{ width w | 460 | The width which you want all images to get to scale the annotations. large images are slow to process [default = 460] }"
  33. "{ height h | 460 | The height which you want all images to get to scale the annotations. large images are slow to process [default = 460] }"
  34. "{ face_cascade f | | Path to the face cascade xml file which you want to use as a detector}"
  35. );
  36. //Read in the input arguments
  37. if (parser.has("help")){
  38. parser.printMessage();
  39. cerr << "TIP: Use absolute paths to avoid any problems with the software!" << endl;
  40. return 0;
  41. }
  42. string directory(parser.get<string>("annotations"));
  43. //default initialisation
  44. Size scale(460,460);
  45. scale = Size(parser.get<int>("width"),parser.get<int>("height"));
  46. if (directory.empty()){
  47. parser.printMessage();
  48. cerr << "The name of the directory from which annotations have to be found is empty" << endl;
  49. return -1;
  50. }
  51. string configfile_name(parser.get<string>("config"));
  52. if (configfile_name.empty()){
  53. parser.printMessage();
  54. cerr << "No configuration file name found which contains the parameters for training" << endl;
  55. return -1;
  56. }
  57. string modelfile_name(parser.get<string>("model"));
  58. if (modelfile_name.empty()){
  59. parser.printMessage();
  60. cerr << "No name for the model_file found in which the trained model has to be saved" << endl;
  61. return -1;
  62. }
  63. string cascade_name(parser.get<string>("face_cascade"));
  64. if (cascade_name.empty()){
  65. parser.printMessage();
  66. cerr << "The name of the cascade classifier to be loaded to detect faces is not found" << endl;
  67. return -1;
  68. }
  69. //create a vector to store names of files in which annotations
  70. //and image names are found
  71. /*The format of the file containing annotations should be of following format
  72. /data/abc/abc.jpg
  73. 123.45,345.65
  74. 321.67,543.89
  75. The above format is similar to HELEN dataset which is used for training model
  76. */
  77. vector<String> filenames;
  78. //reading the files from the given directory
  79. glob(directory + "*.txt",filenames);
  80. //create a pointer to call the base class
  81. //pass the face cascade xml file which you want to pass as a detector
  82. CascadeClassifier face_cascade;
  83. face_cascade.load(cascade_name);
  84. FacemarkKazemi::Params params;
  85. params.configfile = configfile_name;
  86. Ptr<FacemarkKazemi> facemark = FacemarkKazemi::create(params);
  87. facemark->setFaceDetector((FN_FaceDetector)myDetector, &face_cascade);
  88. //create a vector to store image names
  89. vector<String> imagenames;
  90. //create object to get landmarks
  91. vector< vector<Point2f> > trainlandmarks,Trainlandmarks;
  92. //gets landmarks and corresponding image names in both the vectors
  93. //vector to store images
  94. vector<Mat> trainimages;
  95. loadTrainingData(filenames,trainlandmarks,imagenames);
  96. for(unsigned long i=0;i<300;i++){
  97. string imgname = imagenames[i].substr(0, imagenames[i].size()-1);
  98. string img = directory + string(imgname) + ".jpg";
  99. Mat src = imread(img);
  100. if(src.empty()){
  101. cerr<<string("Image "+img+" not found\n.")<<endl;
  102. continue;
  103. }
  104. trainimages.push_back(src);
  105. Trainlandmarks.push_back(trainlandmarks[i]);
  106. }
  107. cout<<"Got data"<<endl;
  108. facemark->training(trainimages,Trainlandmarks,configfile_name,scale,modelfile_name);
  109. cout<<"Training complete"<<endl;
  110. return 0;
  111. }