bagofwords_classification.cpp
  1. #include <iostream>
  2. #include "opencv2/opencv_modules.hpp"
  3. #ifdef HAVE_OPENCV_ML
  4. #include "opencv2/imgcodecs.hpp"
  5. #include "opencv2/highgui.hpp"
  6. #include "opencv2/imgproc.hpp"
  7. #include "opencv2/features2d.hpp"
  8. #include "opencv2/xfeatures2d.hpp"
  9. #include "opencv2/ml.hpp"
  10. #include <fstream>
  11. #include <memory>
  12. #include <functional>
  13. #ifdef _WIN32
  14. #define WIN32_LEAN_AND_MEAN
  15. #include <windows.h>
  16. #undef min
  17. #undef max
  18. #include "sys/types.h"
  19. #endif
  20. #include <sys/stat.h>
  21. #define DEBUG_DESC_PROGRESS
  22. using namespace cv;
  23. using namespace cv::xfeatures2d;
  24. using namespace cv::ml;
  25. using namespace std;
  26. const string paramsFile = "params.xml";
  27. const string vocabularyFile = "vocabulary.xml.gz";
  28. const string bowImageDescriptorsDir = "/bowImageDescriptors";
  29. const string svmsDir = "/svms";
  30. const string plotsDir = "/plots";
  31. static void help(char** argv)
  32. {
  33. cout << "\nThis program shows how to read in, train on and produce test results for the PASCAL VOC (Visual Object Classes) data. \n"
  34. << "It shows how to use detectors, descriptors and recognition methods \n"
  35. "Using OpenCV version " << CV_VERSION << "\n"
  36. << "Call: \n"
  37. << "Format:\n ./" << argv[0] << " [VOC path] [result directory] \n"
  38. << " or: \n"
  39. << " ./" << argv[0] << " [VOC path] [result directory] [feature detector] [descriptor extractor] [descriptor matcher] \n"
  40. << "\n"
  41. << "Input parameters: \n"
  42. << "[VOC path] Path to Pascal VOC data (e.g. /home/my/VOCdevkit/VOC2010). Note: VOC2007-VOC2010 are supported. \n"
  43. << "[result directory] Path to the result directory. The following folders will be created in [result directory]: \n"
  44. << " bowImageDescriptors - to store image descriptors, \n"
  45. << " svms - to store trained svms, \n"
  46. << " plots - to store files for creating plots. \n"
  47. << "[feature detector] Feature detector name (e.g. SURF, FAST...) - see createFeatureDetector() function in detectors.cpp \n"
  48. << " Currently (as of 12/2010) the options are FAST, STAR, SIFT, SURF, MSER, GFTT, HARRIS \n"
  49. << "[descriptor extractor] Descriptor extractor name (e.g. SURF, SIFT) - see createDescriptorExtractor() function in descriptors.cpp \n"
  50. << " Currently (as of 12/2010) the options are SURF, OpponentSIFT, SIFT, OpponentSURF, BRIEF \n"
  51. << "[descriptor matcher] Descriptor matcher name (e.g. BruteForce) - see createDescriptorMatcher() function in matchers.cpp \n"
  52. << " Currently (as of 12/2010) the options are BruteForce, BruteForce-L1, FlannBased, BruteForce-Hamming, BruteForce-HammingLUT \n"
  53. << "\n";
  54. }
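// Example invocations of this sample (illustrative only; the binary name and the
// paths below are placeholders for a local VOC installation and an output directory):
//   ./bagofwords_classification /data/VOCdevkit/VOC2010 /data/voc_results
//   ./bagofwords_classification /data/VOCdevkit/VOC2010 /data/voc_results SURF SURF BruteForce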
  55. static void makeDir( const string& dir )
  56. {
  57. #if defined WIN32 || defined _WIN32
  58. CreateDirectory( dir.c_str(), 0 );
  59. #else
  60. mkdir( dir.c_str(), S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH );
  61. #endif
  62. }
  63. static void makeUsedDirs( const string& rootPath )
  64. {
  65. makeDir(rootPath + bowImageDescriptorsDir);
  66. makeDir(rootPath + svmsDir);
  67. makeDir(rootPath + plotsDir);
  68. }
  69. /****************************************************************************************\
  70. * Classes to work with PASCAL VOC dataset *
  71. \****************************************************************************************/
  72. //
  73. // TODO: refactor this part of the code
  74. //
  75. //used to specify the (sub-)dataset over which operations are performed
  76. enum ObdDatasetType {CV_OBD_TRAIN, CV_OBD_TEST};
  77. class ObdObject
  78. {
  79. public:
  80. string object_class;
  81. Rect boundingBox;
  82. };
  83. //extended object data specific to VOC
  84. enum VocPose {CV_VOC_POSE_UNSPECIFIED, CV_VOC_POSE_FRONTAL, CV_VOC_POSE_REAR, CV_VOC_POSE_LEFT, CV_VOC_POSE_RIGHT};
  85. class VocObjectData
  86. {
  87. public:
  88. bool difficult;
  89. bool occluded;
  90. bool truncated;
  91. VocPose pose;
  92. };
  93. //enum VocDataset {CV_VOC2007, CV_VOC2008, CV_VOC2009, CV_VOC2010};
  94. enum VocPlotType {CV_VOC_PLOT_SCREEN, CV_VOC_PLOT_PNG};
  95. enum VocGT {CV_VOC_GT_NONE, CV_VOC_GT_DIFFICULT, CV_VOC_GT_PRESENT};
  96. enum VocConfCond {CV_VOC_CCOND_RECALL, CV_VOC_CCOND_SCORETHRESH};
  97. enum VocTask {CV_VOC_TASK_CLASSIFICATION, CV_VOC_TASK_DETECTION};
  98. class ObdImage
  99. {
  100. public:
  101. ObdImage(string p_id, string p_path) : id(p_id), path(p_path) {}
  102. string id;
  103. string path;
  104. };
  105. //used by getDetectorGroundTruth to sort a two dimensional list of floats in descending order
  106. class ObdScoreIndexSorter
  107. {
  108. public:
  109. float score;
  110. int image_idx;
  111. int obj_idx;
  112. bool operator < (const ObdScoreIndexSorter& compare) const {return (score < compare.score);}
  113. };
  114. class VocData
  115. {
  116. public:
  117. VocData( const string& vocPath, bool useTestDataset )
  118. { initVoc( vocPath, useTestDataset ); }
  119. ~VocData(){}
  120. /* functions for returning classification/object data for multiple images given an object class */
  121. void getClassImages(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<char>& object_present);
  122. void getClassObjects(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<vector<ObdObject> >& objects);
  123. void getClassObjects(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<vector<ObdObject> >& objects, vector<vector<VocObjectData> >& object_data, vector<VocGT>& ground_truth);
  124. /* functions for returning object data for a single image given an image id */
  125. ObdImage getObjects(const string& id, vector<ObdObject>& objects);
  126. ObdImage getObjects(const string& id, vector<ObdObject>& objects, vector<VocObjectData>& object_data);
  127. ObdImage getObjects(const string& obj_class, const string& id, vector<ObdObject>& objects, vector<VocObjectData>& object_data, VocGT& ground_truth);
  128. /* functions for returning the ground truth (present/absent) for groups of images */
  129. void getClassifierGroundTruth(const string& obj_class, const vector<ObdImage>& images, vector<char>& ground_truth);
  130. void getClassifierGroundTruth(const string& obj_class, const vector<string>& images, vector<char>& ground_truth);
  131. int getDetectorGroundTruth(const string& obj_class, const ObdDatasetType dataset, const vector<ObdImage>& images, const vector<vector<Rect> >& bounding_boxes, const vector<vector<float> >& scores, vector<vector<char> >& ground_truth, vector<vector<char> >& detection_difficult, bool ignore_difficult = true);
  132. /* functions for writing VOC-compatible results files */
  133. void writeClassifierResultsFile(const string& out_dir, const string& obj_class, const ObdDatasetType dataset, const vector<ObdImage>& images, const vector<float>& scores, const int competition = 1, const bool overwrite_ifexists = false);
  134. /* functions for calculating metrics from a set of classification/detection results */
  135. string getResultsFilename(const string& obj_class, const VocTask task, const ObdDatasetType dataset, const int competition = -1, const int number = -1);
  136. void calcClassifierPrecRecall(const string& obj_class, const vector<ObdImage>& images, const vector<float>& scores, vector<float>& precision, vector<float>& recall, float& ap, vector<size_t>& ranking);
  137. void calcClassifierPrecRecall(const string& obj_class, const vector<ObdImage>& images, const vector<float>& scores, vector<float>& precision, vector<float>& recall, float& ap);
  138. void calcClassifierPrecRecall(const string& input_file, vector<float>& precision, vector<float>& recall, float& ap, bool outputRankingFile = false);
  139. /* functions for calculating confusion matrices */
  140. void calcClassifierConfMatRow(const string& obj_class, const vector<ObdImage>& images, const vector<float>& scores, const VocConfCond cond, const float threshold, vector<string>& output_headers, vector<float>& output_values);
  141. void calcDetectorConfMatRow(const string& obj_class, const ObdDatasetType dataset, const vector<ObdImage>& images, const vector<vector<float> >& scores, const vector<vector<Rect> >& bounding_boxes, const VocConfCond cond, const float threshold, vector<string>& output_headers, vector<float>& output_values, bool ignore_difficult = true);
  142. /* functions for outputting gnuplot output files */
  143. void savePrecRecallToGnuplot(const string& output_file, const vector<float>& precision, const vector<float>& recall, const float ap, const string title = string(), const VocPlotType plot_type = CV_VOC_PLOT_SCREEN);
  144. /* functions for reading in result/ground truth files */
  145. void readClassifierGroundTruth(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<char>& object_present);
  146. void readClassifierResultsFile(const std::string& input_file, vector<ObdImage>& images, vector<float>& scores);
  147. void readDetectorResultsFile(const string& input_file, vector<ObdImage>& images, vector<vector<float> >& scores, vector<vector<Rect> >& bounding_boxes);
  148. /* functions for getting dataset info */
  149. const vector<string>& getObjectClasses();
  150. string getResultsDirectory();
  151. protected:
  152. void initVoc( const string& vocPath, const bool useTestDataset );
  153. void initVoc2007to2010( const string& vocPath, const bool useTestDataset);
  154. void readClassifierGroundTruth(const string& filename, vector<string>& image_codes, vector<char>& object_present);
  155. void readClassifierResultsFile(const string& input_file, vector<string>& image_codes, vector<float>& scores);
  156. void readDetectorResultsFile(const string& input_file, vector<string>& image_codes, vector<vector<float> >& scores, vector<vector<Rect> >& bounding_boxes);
  157. void extractVocObjects(const string filename, vector<ObdObject>& objects, vector<VocObjectData>& object_data);
  158. string getImagePath(const string& input_str);
  159. void getClassImages_impl(const string& obj_class, const string& dataset_str, vector<ObdImage>& images, vector<char>& object_present);
  160. void calcPrecRecall_impl(const vector<char>& ground_truth, const vector<float>& scores, vector<float>& precision, vector<float>& recall, float& ap, vector<size_t>& ranking, int recall_normalization = -1);
  161. //test two bounding boxes to see if they meet the overlap criteria defined in the VOC documentation
  162. float testBoundingBoxesForOverlap(const Rect detection, const Rect ground_truth);
  163. //extract class and dataset name from a VOC-standard classification/detection results filename
  164. void extractDataFromResultsFilename(const string& input_file, string& class_name, string& dataset_name);
  165. //get classifier ground truth for a single image
  166. bool getClassifierGroundTruthImage(const string& obj_class, const string& id);
  167. //utility functions
  168. void getSortOrder(const vector<float>& values, vector<size_t>& order, bool descending = true);
  169. int stringToInteger(const string input_str);
  170. void readFileToString(const string filename, string& file_contents);
  171. string integerToString(const int input_int);
  172. string checkFilenamePathsep(const string filename, bool add_trailing_slash = false);
  173. void convertImageCodesToObdImages(const vector<string>& image_codes, vector<ObdImage>& images);
  174. int extractXMLBlock(const string src, const string tag, const int searchpos, string& tag_contents);
  175. //utility sorter
  176. struct orderingSorter
  177. {
  178. bool operator ()(std::pair<size_t, vector<float>::const_iterator> const& a, std::pair<size_t, vector<float>::const_iterator> const& b)
  179. {
  180. return (*a.second) > (*b.second);
  181. }
  182. };
  183. //data members
  184. string m_vocPath;
  185. string m_vocName;
  186. //string m_resPath;
  187. string m_annotation_path;
  188. string m_image_path;
  189. string m_imageset_path;
  190. string m_class_imageset_path;
  191. vector<string> m_classifier_gt_all_ids;
  192. vector<char> m_classifier_gt_all_present;
  193. string m_classifier_gt_class;
  194. //data members
  195. string m_train_set;
  196. string m_test_set;
  197. vector<string> m_object_classes;
  198. float m_min_overlap;
  199. bool m_sampled_ap;
  200. };
  201. //Return the classification ground truth data for all images of a given VOC object class
  202. //--------------------------------------------------------------------------------------
  203. //INPUTS:
  204. // - obj_class The VOC object class identifier string
  205. // - dataset Specifies whether to extract images from the training or test set
  206. //OUTPUTS:
  207. // - images An array of ObdImage containing info of all images extracted from the ground truth file
  208. // - object_present An array of bools specifying whether the object defined by 'obj_class' is present in each image or not
  209. //NOTES:
  210. // This function is primarily useful for the classification task, where only
  211. // whether a given object is present or not in an image is required, and not each object instance's
  212. // position etc.
  213. void VocData::getClassImages(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<char>& object_present)
  214. {
  215. string dataset_str;
  216. //generate the filename of the classification ground-truth textfile for the object class
  217. if (dataset == CV_OBD_TRAIN)
  218. {
  219. dataset_str = m_train_set;
  220. } else {
  221. dataset_str = m_test_set;
  222. }
  223. getClassImages_impl(obj_class, dataset_str, images, object_present);
  224. }
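// Usage sketch for the classification ground truth above (illustrative only;
// 'voc_data' and the local vectors are hypothetical names, not part of this sample):
//   VocData voc_data("/data/VOCdevkit/VOC2010", false);
//   vector<ObdImage> images;
//   vector<char> object_present;
//   voc_data.getClassImages("aeroplane", CV_OBD_TRAIN, images, object_present);
//   // images[i].id/path identify each training image; object_present[i] flags
//   // whether an 'aeroplane' instance is annotated in that image.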
  225. void VocData::getClassImages_impl(const string& obj_class, const string& dataset_str, vector<ObdImage>& images, vector<char>& object_present)
  226. {
  227. //generate the filename of the classification ground-truth textfile for the object class
  228. string gtFilename = m_class_imageset_path;
  229. gtFilename.replace(gtFilename.find("%s"),2,obj_class);
  230. gtFilename.replace(gtFilename.find("%s"),2,dataset_str);
  231. //parse the ground truth file, storing in two separate vectors
  232. //for the image code and the ground truth value
  233. vector<string> image_codes;
  234. readClassifierGroundTruth(gtFilename, image_codes, object_present);
  235. //prepare output arrays
  236. images.clear();
  237. convertImageCodesToObdImages(image_codes, images);
  238. }
  239. //Return the object data for all images of a given VOC object class
  240. //-----------------------------------------------------------------
  241. //INPUTS:
  242. // - obj_class The VOC object class identifier string
  243. // - dataset Specifies whether to extract images from the training or test set
  244. //OUTPUTS:
  245. // - images An array of ObdImage containing info of all images in chosen dataset (tag, path etc.)
  246. // - objects Contains the extended object info (bounding box etc.) for each object instance in each image
  247. // - object_data Contains VOC-specific extended object info (marked difficult etc.)
  248. // - ground_truth Specifies whether there are any difficult/non-difficult instances of the current
  249. // object class within each image
  250. //NOTES:
  251. // This function returns extended object information in addition to the absent/present
  252. // classification data returned by getClassImages. The objects returned for each image in the 'objects'
  253. // array are of all object classes present in the image, and not just the class defined by 'obj_class'.
  254. // 'ground_truth' can be used to determine quickly whether an object instance of the given class is present
  255. // in an image or not.
  256. void VocData::getClassObjects(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<vector<ObdObject> >& objects)
  257. {
  258. vector<vector<VocObjectData> > object_data;
  259. vector<VocGT> ground_truth;
  260. getClassObjects(obj_class,dataset,images,objects,object_data,ground_truth);
  261. }
  262. void VocData::getClassObjects(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<vector<ObdObject> >& objects, vector<vector<VocObjectData> >& object_data, vector<VocGT>& ground_truth)
  263. {
  264. //generate the filename of the classification ground-truth textfile for the object class
  265. string gtFilename = m_class_imageset_path;
  266. gtFilename.replace(gtFilename.find("%s"),2,obj_class);
  267. if (dataset == CV_OBD_TRAIN)
  268. {
  269. gtFilename.replace(gtFilename.find("%s"),2,m_train_set);
  270. } else {
  271. gtFilename.replace(gtFilename.find("%s"),2,m_test_set);
  272. }
  273. //parse the ground truth file, storing in two separate vectors
  274. //for the image code and the ground truth value
  275. vector<string> image_codes;
  276. vector<char> object_present;
  277. readClassifierGroundTruth(gtFilename, image_codes, object_present);
  278. //prepare output arrays
  279. images.clear();
  280. objects.clear();
  281. object_data.clear();
  282. ground_truth.clear();
  283. string annotationFilename;
  284. vector<ObdObject> image_objects;
  285. vector<VocObjectData> image_object_data;
  286. VocGT image_gt;
  287. //transfer to output arrays and read in object data for each image
  288. for (size_t i = 0; i < image_codes.size(); ++i)
  289. {
  290. ObdImage image = getObjects(obj_class, image_codes[i], image_objects, image_object_data, image_gt);
  291. images.push_back(image);
  292. objects.push_back(image_objects);
  293. object_data.push_back(image_object_data);
  294. ground_truth.push_back(image_gt);
  295. }
  296. }
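// Usage sketch (illustrative; 'voc_data' as above):
//   vector<ObdImage> images;
//   vector<vector<ObdObject> > objects;
//   vector<vector<VocObjectData> > object_data;
//   vector<VocGT> gt;
//   voc_data.getClassObjects("car", CV_OBD_TRAIN, images, objects, object_data, gt);
//   // objects[i] holds every annotated object in images[i] (all classes), while
//   // gt[i] reports whether a difficult/non-difficult 'car' instance is present.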
  297. //Return ground truth data for the objects present in an image with a given UID
  298. //-----------------------------------------------------------------------------
  299. //INPUTS:
  300. // - id VOC Dataset unique identifier (string code in form YYYY_XXXXXX where YYYY is the year)
  301. //OUTPUTS:
  302. // - obj_class (*3) Specifies the object class to use to resolve 'ground_truth'
  303. // - objects Contains the extended object info (bounding box etc.) for each object in the image
  304. // - object_data (*2,3) Contains VOC-specific extended object info (marked difficult etc.)
  305. // - ground_truth (*3) Specifies whether there are any difficult/non-difficult instances of the current
  306. // object class within the image
  307. //RETURN VALUE:
  308. // ObdImage containing path and other details of image file with given code
  309. //NOTES:
  310. // There are three versions of this function
  311. // * One returns a simple array of objects given an id [1]
  312. // * One returns the same as (1) plus VOC specific object data [2]
  313. // * One returns the same as (2) plus the ground_truth flag. This also requires an extra input obj_class [3]
  314. ObdImage VocData::getObjects(const string& id, vector<ObdObject>& objects)
  315. {
  316. vector<VocObjectData> object_data;
  317. ObdImage image = getObjects(id, objects, object_data);
  318. return image;
  319. }
  320. ObdImage VocData::getObjects(const string& id, vector<ObdObject>& objects, vector<VocObjectData>& object_data)
  321. {
  322. //first generate the filename of the annotation file
  323. string annotationFilename = m_annotation_path;
  324. annotationFilename.replace(annotationFilename.find("%s"),2,id);
  325. //extract objects contained in the current image from the xml
  326. extractVocObjects(annotationFilename,objects,object_data);
  327. //generate image path from extracted string code
  328. string path = getImagePath(id);
  329. ObdImage image(id, path);
  330. return image;
  331. }
  332. ObdImage VocData::getObjects(const string& obj_class, const string& id, vector<ObdObject>& objects, vector<VocObjectData>& object_data, VocGT& ground_truth)
  333. {
  334. //extract object data (except for ground truth flag)
  335. ObdImage image = getObjects(id,objects,object_data);
  336. //pregenerate a flag to indicate whether the current class is present or not in the image
  337. ground_truth = CV_VOC_GT_NONE;
  338. //iterate through all objects in current image
  339. for (size_t j = 0; j < objects.size(); ++j)
  340. {
  341. if (objects[j].object_class == obj_class)
  342. {
  343. if (object_data[j].difficult == false)
  344. {
  345. //if at least one non-difficult example is present, this flag is always set to CV_VOC_GT_PRESENT
  346. ground_truth = CV_VOC_GT_PRESENT;
  347. break;
  348. } else {
  349. //set if at least one object instance is present, but it is marked difficult
  350. ground_truth = CV_VOC_GT_DIFFICULT;
  351. }
  352. }
  353. }
  354. return image;
  355. }
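// Usage sketch (illustrative; "2010_000123" is a made-up image id):
//   vector<ObdObject> objs;
//   vector<VocObjectData> obj_data;
//   VocGT gt_flag;
//   ObdImage img = voc_data.getObjects("car", "2010_000123", objs, obj_data, gt_flag);
//   // objs[k].boundingBox is each annotated object's rectangle; gt_flag is one of
//   // CV_VOC_GT_PRESENT, CV_VOC_GT_DIFFICULT or CV_VOC_GT_NONE for the 'car' class.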
  356. //Return ground truth data for the presence/absence of a given object class in an arbitrary array of images
  357. //---------------------------------------------------------------------------------------------------------
  358. //INPUTS:
  359. // - obj_class The VOC object class identifier string
  360. // - images An array of ObdImage OR strings containing the images for which ground truth
  361. // will be computed
  362. //OUTPUTS:
  363. // - ground_truth An output array indicating the presence/absence of obj_class within each image
  364. void VocData::getClassifierGroundTruth(const string& obj_class, const vector<ObdImage>& images, vector<char>& ground_truth)
  365. {
  366. vector<char>(images.size()).swap(ground_truth);
  367. vector<ObdObject> objects;
  368. vector<VocObjectData> object_data;
  369. vector<char>::iterator gt_it = ground_truth.begin();
  370. for (vector<ObdImage>::const_iterator it = images.begin(); it != images.end(); ++it, ++gt_it)
  371. {
  372. //getObjects(obj_class, it->id, objects, object_data, voc_ground_truth);
  373. (*gt_it) = (getClassifierGroundTruthImage(obj_class, it->id));
  374. }
  375. }
  376. void VocData::getClassifierGroundTruth(const string& obj_class, const vector<string>& images, vector<char>& ground_truth)
  377. {
  378. vector<char>(images.size()).swap(ground_truth);
  379. vector<ObdObject> objects;
  380. vector<VocObjectData> object_data;
  381. vector<char>::iterator gt_it = ground_truth.begin();
  382. for (vector<string>::const_iterator it = images.begin(); it != images.end(); ++it, ++gt_it)
  383. {
  384. //getObjects(obj_class, (*it), objects, object_data, voc_ground_truth);
  385. (*gt_it) = (getClassifierGroundTruthImage(obj_class, (*it)));
  386. }
  387. }
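// Usage sketch (illustrative; 'images' is any set of ObdImage, e.g. from getClassImages):
//   vector<char> gt;
//   voc_data.getClassifierGroundTruth("dog", images, gt);
//   // gt[i] is nonzero when 'dog' is marked present in images[i]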
  388. //Return ground truth data for the accuracy of detection results
  389. //--------------------------------------------------------------
  390. //INPUTS:
  391. // - obj_class The VOC object class identifier string
  392. // - images An array of ObdImage containing the images for which ground truth
  393. // will be computed
  394. // - bounding_boxes A 2D input array containing the bounding box rects of the objects of
  395. // obj_class which were detected in each image
  396. //OUTPUTS:
  397. // - ground_truth A 2D output array indicating whether each object detection was accurate
  398. // or not
  399. // - detection_difficult A 2D output array indicating whether the detection fired on an object
  400. // marked as 'difficult'. This allows it to be ignored if necessary
  401. // (the voc documentation specifies objects marked as difficult
  402. // have no effect on the results and are effectively ignored)
  403. // - (ignore_difficult) If set to true, objects marked as difficult will be ignored when returning
  404. // the number of hits for p-r normalization (default = true)
  405. //RETURN VALUE:
  406. // Returns the number of object hits in total in the gt to allow proper normalization
  407. // of a p-r curve
  408. //NOTES:
  409. // As stated in the VOC documentation, multiple detections of the same object in an image are
  410. // considered FALSE detections, e.g. 5 detections of a single object are counted as 1 correct
  411. // detection and 4 false detections - it is the responsibility of the participant's system
  412. // to filter multiple detections from its output
  413. int VocData::getDetectorGroundTruth(const string& obj_class, const ObdDatasetType dataset, const vector<ObdImage>& images, const vector<vector<Rect> >& bounding_boxes, const vector<vector<float> >& scores, vector<vector<char> >& ground_truth, vector<vector<char> >& detection_difficult, bool ignore_difficult)
  414. {
  415. int recall_normalization = 0;
  416. /* first create a list of indices referring to the elements of bounding_boxes and scores in
  417. * descending order of scores */
  418. vector<ObdScoreIndexSorter> sorted_ids;
  419. {
  420. /* first count how many objects to allow preallocation */
  421. size_t obj_count = 0;
  422. CV_Assert(images.size() == bounding_boxes.size());
  423. CV_Assert(scores.size() == bounding_boxes.size());
  424. for (size_t im_idx = 0; im_idx < scores.size(); ++im_idx)
  425. {
  426. CV_Assert(scores[im_idx].size() == bounding_boxes[im_idx].size());
  427. obj_count += scores[im_idx].size();
  428. }
  429. /* preallocate id vector */
  430. sorted_ids.resize(obj_count);
  431. /* now copy across scores and indexes to preallocated vector */
  432. int flat_pos = 0;
  433. for (size_t im_idx = 0; im_idx < scores.size(); ++im_idx)
  434. {
  435. for (size_t ob_idx = 0; ob_idx < scores[im_idx].size(); ++ob_idx)
  436. {
  437. sorted_ids[flat_pos].score = scores[im_idx][ob_idx];
  438. sorted_ids[flat_pos].image_idx = (int)im_idx;
  439. sorted_ids[flat_pos].obj_idx = (int)ob_idx;
  440. ++flat_pos;
  441. }
  442. }
  443. /* and sort the vector in descending order of score */
  444. std::sort(sorted_ids.begin(),sorted_ids.end());
  445. std::reverse(sorted_ids.begin(),sorted_ids.end());
  446. }
  447. /* prepare ground truth + difficult vector (1st dimension) */
  448. vector<vector<char> >(images.size()).swap(ground_truth);
  449. vector<vector<char> >(images.size()).swap(detection_difficult);
  450. vector<vector<char> > detected(images.size());
  451. vector<vector<ObdObject> > img_objects(images.size());
  452. vector<vector<VocObjectData> > img_object_data(images.size());
  453. /* preload object ground truth bounding box data */
  454. {
  455. vector<vector<ObdObject> > img_objects_all(images.size());
  456. vector<vector<VocObjectData> > img_object_data_all(images.size());
  457. for (size_t image_idx = 0; image_idx < images.size(); ++image_idx)
  458. {
  459. /* prepopulate ground truth bounding boxes */
  460. getObjects(images[image_idx].id, img_objects_all[image_idx], img_object_data_all[image_idx]);
  461. /* meanwhile, also set length of target ground truth + difficult vector to same as number of object detections (2nd dimension) */
  462. ground_truth[image_idx].resize(bounding_boxes[image_idx].size());
  463. detection_difficult[image_idx].resize(bounding_boxes[image_idx].size());
  464. }
  465. /* save only instances of the object class concerned */
  466. for (size_t image_idx = 0; image_idx < images.size(); ++image_idx)
  467. {
  468. for (size_t obj_idx = 0; obj_idx < img_objects_all[image_idx].size(); ++obj_idx)
  469. {
  470. if (img_objects_all[image_idx][obj_idx].object_class == obj_class)
  471. {
  472. img_objects[image_idx].push_back(img_objects_all[image_idx][obj_idx]);
  473. img_object_data[image_idx].push_back(img_object_data_all[image_idx][obj_idx]);
  474. }
  475. }
  476. detected[image_idx].resize(img_objects[image_idx].size(), false);
  477. }
  478. }
  479. /* calculate the total number of objects in the ground truth for the current dataset */
  480. {
  481. vector<ObdImage> gt_images;
  482. vector<char> gt_object_present;
  483. getClassImages(obj_class, dataset, gt_images, gt_object_present);
  484. for (size_t image_idx = 0; image_idx < gt_images.size(); ++image_idx)
  485. {
  486. vector<ObdObject> gt_img_objects;
  487. vector<VocObjectData> gt_img_object_data;
  488. getObjects(gt_images[image_idx].id, gt_img_objects, gt_img_object_data);
  489. for (size_t obj_idx = 0; obj_idx < gt_img_objects.size(); ++obj_idx)
  490. {
  491. if (gt_img_objects[obj_idx].object_class == obj_class)
  492. {
  493. if ((gt_img_object_data[obj_idx].difficult == false) || (ignore_difficult == false))
  494. ++recall_normalization;
  495. }
  496. }
  497. }
  498. }
  499. #ifdef PR_DEBUG
  500. int printed_count = 0;
  501. #endif
  502. /* now iterate through detections in descending order of score, assigning to ground truth bounding boxes if possible */
  503. for (size_t detect_idx = 0; detect_idx < sorted_ids.size(); ++detect_idx)
  504. {
  505. //read in indexes to make following code easier to read
  506. int im_idx = sorted_ids[detect_idx].image_idx;
  507. int ob_idx = sorted_ids[detect_idx].obj_idx;
  508. //set ground truth for the current object to false by default
  509. ground_truth[im_idx][ob_idx] = false;
  510. detection_difficult[im_idx][ob_idx] = false;
  511. float maxov = -1.0;
  512. bool max_is_difficult = false;
  513. int max_gt_obj_idx = -1;
  514. //-- for each detected object iterate through objects present in the bounding box ground truth --
  515. for (size_t gt_obj_idx = 0; gt_obj_idx < img_objects[im_idx].size(); ++gt_obj_idx)
  516. {
  517. if (detected[im_idx][gt_obj_idx] == false)
  518. {
  519. //check if the detected object and ground truth object overlap by a sufficient margin
  520. float ov = testBoundingBoxesForOverlap(bounding_boxes[im_idx][ob_idx], img_objects[im_idx][gt_obj_idx].boundingBox);
  521. if (ov != -1.0)
  522. {
  523. //if all conditions are met store the overlap score and index (as objects are assigned to the highest scoring match)
  524. if (ov > maxov)
  525. {
  526. maxov = ov;
  527. max_gt_obj_idx = (int)gt_obj_idx;
  528. //store whether the maximum detection is marked as difficult or not
  529. max_is_difficult = (img_object_data[im_idx][gt_obj_idx].difficult);
  530. }
  531. }
  532. }
  533. }
  534. //-- if a match was found, set the ground truth of the current object to true --
  535. if (maxov != -1.0)
  536. {
  537. CV_Assert(max_gt_obj_idx != -1);
  538. ground_truth[im_idx][ob_idx] = true;
  539. //store whether the maximum detection was marked as 'difficult' or not
  540. detection_difficult[im_idx][ob_idx] = max_is_difficult;
  541. //remove the ground truth object so it doesn't match with subsequent detected objects
  542. //** this is the behaviour defined by the voc documentation **
  543. detected[im_idx][max_gt_obj_idx] = true;
  544. }
  545. #ifdef PR_DEBUG
  546. if (printed_count < 10)
  547. {
  548. cout << printed_count << ": id=" << images[im_idx].id << ", score=" << scores[im_idx][ob_idx] << " (" << ob_idx << ") [" << bounding_boxes[im_idx][ob_idx].x << "," <<
  549. bounding_boxes[im_idx][ob_idx].y << "," << bounding_boxes[im_idx][ob_idx].width + bounding_boxes[im_idx][ob_idx].x <<
  550. "," << bounding_boxes[im_idx][ob_idx].height + bounding_boxes[im_idx][ob_idx].y << "] detected=" << ground_truth[im_idx][ob_idx] <<
  551. ", difficult=" << detection_difficult[im_idx][ob_idx] << endl;
  552. ++printed_count;
  553. /* print ground truth */
  554. for (int gt_obj_idx = 0; gt_obj_idx < (int)img_objects[im_idx].size(); ++gt_obj_idx)
  555. {
  556. cout << " GT: [" << img_objects[im_idx][gt_obj_idx].boundingBox.x << "," <<
  557. img_objects[im_idx][gt_obj_idx].boundingBox.y << "," << img_objects[im_idx][gt_obj_idx].boundingBox.width + img_objects[im_idx][gt_obj_idx].boundingBox.x <<
  558. "," << img_objects[im_idx][gt_obj_idx].boundingBox.height + img_objects[im_idx][gt_obj_idx].boundingBox.y << "]";
  559. if (gt_obj_idx == max_gt_obj_idx) cout << " <--- (" << maxov << " overlap)";
  560. cout << endl;
  561. }
  562. }
  563. #endif
  564. }
  565. return recall_normalization;
  566. }
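// Usage sketch (illustrative; det_boxes/det_scores stand for hypothetical detector output):
//   vector<vector<Rect> > det_boxes;    // det_boxes[i]  - detections in images[i]
//   vector<vector<float> > det_scores;  // det_scores[i] - corresponding confidences
//   vector<vector<char> > is_true_pos, hit_difficult;
//   int n_pos = voc_data.getDetectorGroundTruth("car", CV_OBD_TEST, images,
//                                               det_boxes, det_scores,
//                                               is_true_pos, hit_difficult);
//   // n_pos is the recall normalization count for a precision-recall curve; repeated
//   // detections of one ground-truth box are marked false, as required by the VOC rules.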
  567. //Write VOC-compliant classifier results file
  568. //-------------------------------------------
  569. //INPUTS:
  570. // - obj_class The VOC object class identifier string
  571. // - dataset Specifies whether working with the training or test set
  572. // - images An array of ObdImage containing the images for which data will be saved to the result file
  573. // - scores A corresponding array of confidence scores given a query
  574. // - (competition) If specified, defines which competition the results are for (see VOC documentation - default 1)
  575. //NOTES:
  576. // The result file path and filename are determined automatically using m_results_directory as a base
  577. void VocData::writeClassifierResultsFile( const string& out_dir, const string& obj_class, const ObdDatasetType dataset, const vector<ObdImage>& images, const vector<float>& scores, const int competition, const bool overwrite_ifexists)
  578. {
  579. CV_Assert(images.size() == scores.size());
  580. string output_file_base, output_file;
  581. if (dataset == CV_OBD_TRAIN)
  582. {
  583. output_file_base = out_dir + "/comp" + integerToString(competition) + "_cls_" + m_train_set + "_" + obj_class;
  584. } else {
  585. output_file_base = out_dir + "/comp" + integerToString(competition) + "_cls_" + m_test_set + "_" + obj_class;
  586. }
  587. output_file = output_file_base + ".txt";
  588. //check if file exists, and if so create a numbered new file instead
  589. if (overwrite_ifexists == false)
  590. {
  591. struct stat stFileInfo;
  592. if (stat(output_file.c_str(),&stFileInfo) == 0)
  593. {
  594. string output_file_new;
  595. int filenum = 0;
  596. do
  597. {
  598. ++filenum;
  599. output_file_new = output_file_base + "_" + integerToString(filenum);
  600. output_file = output_file_new + ".txt";
  601. } while (stat(output_file.c_str(),&stFileInfo) == 0);
  602. }
  603. }
  604. //output data to file
  605. std::ofstream result_file(output_file.c_str());
  606. if (result_file.is_open())
  607. {
  608. for (size_t i = 0; i < images.size(); ++i)
  609. {
  610. result_file << images[i].id << " " << scores[i] << endl;
  611. }
  612. result_file.close();
  613. } else {
  614. string err_msg = "could not open classifier results file '" + output_file + "' for writing. Before running for the first time, a 'results' subdirectory should be created within the VOC dataset base directory. e.g. if the VOC data is stored in /VOC/VOC2010 then the path /VOC/results must be created.";
  615. CV_Error(Error::StsError,err_msg.c_str());
  616. }
  617. }
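// Usage sketch (illustrative; the output directory and the images/scores arrays are placeholders):
//   voc_data.writeClassifierResultsFile("/data/voc_results", "car", CV_OBD_TEST,
//                                       images, scores);
//   // writes e.g. comp1_cls_<test set>_car.txt with one "<image id> <score>" line per image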
  618. //---------------------------------------
  619. //CALCULATE METRICS FROM VOC RESULTS DATA
  620. //---------------------------------------
  621. //Utility function to construct a VOC-standard classification results filename
  622. //----------------------------------------------------------------------------
  623. //INPUTS:
  624. // - obj_class The VOC object class identifier string
  625. // - task Specifies whether to generate a filename for the classification or detection task
  626. // - dataset Specifies whether working with the training or test set
  627. // - (competition) If specified, defines which competition the results are for (see VOC documentation
  628. // default of -1 means this is set to 1 for the classification task and 3 for the detection task)
  629. // - (number) If specified and above 0, defines which of a number of duplicate results files produced for a given set
  630. // of settings should be used (this number will be added as a postfix to the filename)
  631. //NOTES:
  632. // This is primarily useful for returning the filename of a classification file previously computed using writeClassifierResultsFile
  633. // for example when calling calcClassifierPrecRecall
  634. string VocData::getResultsFilename(const string& obj_class, const VocTask task, const ObdDatasetType dataset, const int competition, const int number)
  635. {
  636. if ((competition < 1) && (competition != -1))
  637. CV_Error(Error::StsBadArg,"competition argument should be a positive non-zero number or -1 to accept the default");
  638. if ((number < 1) && (number != -1))
  639. CV_Error(Error::StsBadArg,"number argument should be a positive non-zero number or -1 to accept the default");
  640. string dset, task_type;
  641. if (dataset == CV_OBD_TRAIN)
  642. {
  643. dset = m_train_set;
  644. } else {
  645. dset = m_test_set;
  646. }
  647. int comp = competition;
  648. if (task == CV_VOC_TASK_CLASSIFICATION)
  649. {
  650. task_type = "cls";
  651. if (comp == -1) comp = 1;
  652. } else {
  653. task_type = "det";
  654. if (comp == -1) comp = 3;
  655. }
  656. stringstream ss;
  657. if (number < 1)
  658. {
  659. ss << "comp" << comp << "_" << task_type << "_" << dset << "_" << obj_class << ".txt";
  660. } else {
  661. ss << "comp" << comp << "_" << task_type << "_" << dset << "_" << obj_class << "_" << number << ".txt";
  662. }
  663. string filename = ss.str();
  664. return filename;
  665. }
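// Examples of the generated filenames (assuming m_test_set == "test"):
//   getResultsFilename("car", CV_VOC_TASK_CLASSIFICATION, CV_OBD_TEST)    -> "comp1_cls_test_car.txt"
//   getResultsFilename("car", CV_VOC_TASK_DETECTION, CV_OBD_TEST, -1, 2)  -> "comp3_det_test_car_2.txt"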
  666. //Calculate metrics for classification results
  667. //--------------------------------------------
  668. //INPUTS:
  669. // - ground_truth A vector of booleans determining whether the currently tested class is present in each input image
  670. // - scores A vector containing the similarity score for each input image (higher is more similar)
  671. //OUTPUTS:
  672. // - precision A vector containing the precision calculated at each datapoint of a p-r curve generated from the result set
  673. // - recall A vector containing the recall calculated at each datapoint of a p-r curve generated from the result set
  674. // - ap The ap metric calculated from the result set
  675. // - (ranking) A vector of the same length as 'ground_truth' and 'scores' containing the order of the indices in both of
  676. // these arrays when sorting by the ranking score in descending order
  677. //NOTES:
  678. // The result file path and filename are determined automatically using m_results_directory as a base
  679. void VocData::calcClassifierPrecRecall(const string& obj_class, const vector<ObdImage>& images, const vector<float>& scores, vector<float>& precision, vector<float>& recall, float& ap, vector<size_t>& ranking)
  680. {
  681. vector<char> res_ground_truth;
  682. getClassifierGroundTruth(obj_class, images, res_ground_truth);
  683. calcPrecRecall_impl(res_ground_truth, scores, precision, recall, ap, ranking);
  684. }
  685. void VocData::calcClassifierPrecRecall(const string& obj_class, const vector<ObdImage>& images, const vector<float>& scores, vector<float>& precision, vector<float>& recall, float& ap)
  686. {
  687. vector<char> res_ground_truth;
  688. getClassifierGroundTruth(obj_class, images, res_ground_truth);
  689. vector<size_t> ranking;
  690. calcPrecRecall_impl(res_ground_truth, scores, precision, recall, ap, ranking);
  691. }
  692. //< Overloaded version which accepts VOC classification result file input instead of array of scores/ground truth >
  693. //INPUTS:
  694. // - input_file The path to the VOC standard results file to use for calculating precision/recall
  695. // If a full path is not specified, it is assumed this file is in the VOC standard results directory
  696. // A VOC standard filename can be retrieved (as used by writeClassifierResultsFile) by calling getResultsFilename
  697. void VocData::calcClassifierPrecRecall(const string& input_file, vector<float>& precision, vector<float>& recall, float& ap, bool outputRankingFile)
  698. {
  699. //read in classification results file
  700. vector<string> res_image_codes;
  701. vector<float> res_scores;
  702. string input_file_std = checkFilenamePathsep(input_file);
  703. readClassifierResultsFile(input_file_std, res_image_codes, res_scores);
  704. //extract the object class and dataset from the results file filename
  705. string class_name, dataset_name;
  706. extractDataFromResultsFilename(input_file_std, class_name, dataset_name);
  707. //generate the ground truth for the images extracted from the results file
  708. vector<char> res_ground_truth;
  709. getClassifierGroundTruth(class_name, res_image_codes, res_ground_truth);
  710. if (outputRankingFile)
  711. {
  712. /* 1. store sorting order by score (descending) in 'order' */
  713. vector<std::pair<size_t, vector<float>::const_iterator> > order(res_scores.size());
  714. size_t n = 0;
  715. for (vector<float>::const_iterator it = res_scores.begin(); it != res_scores.end(); ++it, ++n)
  716. order[n] = make_pair(n, it);
  717. std::sort(order.begin(),order.end(),orderingSorter());
  718. /* 2. save ranking results to text file */
  719. string input_file_std1 = checkFilenamePathsep(input_file);
  720. size_t fnamestart = input_file_std1.rfind("/");
  721. string scoregt_file_str = input_file_std1.substr(0,fnamestart+1) + "scoregt_" + class_name + ".txt";
  722. std::ofstream scoregt_file(scoregt_file_str.c_str());
  723. if (scoregt_file.is_open())
  724. {
  725. for (size_t i = 0; i < res_scores.size(); ++i)
  726. {
  727. scoregt_file << res_image_codes[order[i].first] << " " << res_scores[order[i].first] << " " << res_ground_truth[order[i].first] << endl;
  728. }
  729. scoregt_file.close();
  730. } else {
  731. string err_msg = "could not open scoregt file '" + scoregt_file_str + "' for writing.";
  732. CV_Error(Error::StsError,err_msg.c_str());
  733. }
  734. }
  735. //finally, calculate precision+recall+ap
  736. vector<size_t> ranking;
  737. calcPrecRecall_impl(res_ground_truth,res_scores,precision,recall,ap,ranking);
  738. }
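// Usage sketch (illustrative; 'images'/'scores' are per-image classifier outputs):
//   vector<float> precision, recall;
//   float ap;
//   voc_data.calcClassifierPrecRecall("car", images, scores, precision, recall, ap);
//   // or, starting from a previously written VOC results file:
//   voc_data.calcClassifierPrecRecall("/data/voc_results/comp1_cls_test_car.txt",
//                                     precision, recall, ap);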
  739. //< Protected implementation of Precision-Recall calculation used by both calcClassifierPrecRecall and calcDetectorPrecRecall >
  740. void VocData::calcPrecRecall_impl(const vector<char>& ground_truth, const vector<float>& scores, vector<float>& precision, vector<float>& recall, float& ap, vector<size_t>& ranking, int recall_normalization)
  741. {
  742. CV_Assert(ground_truth.size() == scores.size());
  743. //add an extra element for p-r at 0 recall (in case the first retrieved image is positive)
  744. vector<float>(scores.size()+1).swap(precision);
  745. vector<float>(scores.size()+1).swap(recall);
  746. // SORT RESULTS BY THEIR SCORE
  747. /* 1. store sorting order in 'order' */
  748. VocData::getSortOrder(scores, ranking);
  749. #ifdef PR_DEBUG
  750. std::ofstream scoregt_file("D:/pr.txt");
  751. if (scoregt_file.is_open())
  752. {
  753. for (int i = 0; i < (int)scores.size(); ++i)
  754. {
  755. scoregt_file << scores[ranking[i]] << " " << ground_truth[ranking[i]] << endl;
  756. }
  757. scoregt_file.close();
  758. }
  759. #endif
  760. // CALCULATE PRECISION+RECALL
  761. int retrieved_hits = 0;
  762. int recall_norm;
  763. if (recall_normalization != -1)
  764. {
  765. recall_norm = recall_normalization;
  766. } else {
  767. #ifdef CV_CXX11
  768. recall_norm = (int)std::count_if(ground_truth.begin(),ground_truth.end(),
  769. [](const char a) { return a == (char)1; });
  770. #else
  771. recall_norm = (int)std::count_if(ground_truth.begin(),ground_truth.end(),std::bind2nd(std::equal_to<char>(),(char)1));
  772. #endif
  773. }
  774. ap = 0;
  775. recall[0] = 0;
  776. for (size_t idx = 0; idx < ground_truth.size(); ++idx)
  777. {
  778. if (ground_truth[ranking[idx]] != 0) ++retrieved_hits;
  779. precision[idx+1] = static_cast<float>(retrieved_hits)/static_cast<float>(idx+1);
  780. recall[idx+1] = static_cast<float>(retrieved_hits)/static_cast<float>(recall_norm);
  781. if (idx == 0)
  782. {
  783. //add further point at 0 recall with the same precision value as the first computed point
  784. precision[idx] = precision[idx+1];
  785. }
  786. if (recall[idx+1] == 1.0)
  787. {
  788. //if recall = 1, then end early as all positive images have been found
  789. recall.resize(idx+2);
  790. precision.resize(idx+2);
  791. break;
  792. }
  793. }
  794. /* ap calculation */
  795. if (m_sampled_ap == false)
  796. {
  797. // FOR VOC2010+ AP IS CALCULATED FROM ALL DATAPOINTS
  798. /* make precision monotonically decreasing for purposes of calculating ap */
  799. vector<float> precision_monot(precision.size());
  800. vector<float>::iterator prec_m_it = precision_monot.begin();
  801. for (vector<float>::iterator prec_it = precision.begin(); prec_it != precision.end(); ++prec_it, ++prec_m_it)
  802. {
  803. vector<float>::iterator max_elem;
  804. max_elem = std::max_element(prec_it,precision.end());
  805. (*prec_m_it) = (*max_elem);
  806. }
  807. /* calculate ap */
  808. for (size_t idx = 0; idx < (recall.size()-1); ++idx)
  809. {
  810. ap += (recall[idx+1] - recall[idx])*precision_monot[idx+1] + //no need to take min of prec - is monotonically decreasing
  811. 0.5f*(recall[idx+1] - recall[idx])*std::abs(precision_monot[idx+1] - precision_monot[idx]);
  812. }
  813. } else {
  814. // FOR BEFORE VOC2010 AP IS CALCULATED BY SAMPLING PRECISION AT RECALL 0.0,0.1,..,1.0
  815. for (float recall_pos = 0.f; recall_pos <= 1.f; recall_pos += 0.1f)
  816. {
  817. //find iterator of the precision corresponding to the first recall >= recall_pos
  818. vector<float>::iterator recall_it = recall.begin();
  819. vector<float>::iterator prec_it = precision.begin();
  820. while ((*recall_it) < recall_pos)
  821. {
  822. ++recall_it;
  823. ++prec_it;
  824. if (recall_it == recall.end()) break;
  825. }
  826. /* if no recall >= recall_pos found, this level of recall is never reached so stop adding to ap */
  827. if (recall_it == recall.end()) break;
  828. /* if the prec_it is valid, compute the max precision at this level of recall or higher */
  829. vector<float>::iterator max_prec = std::max_element(prec_it,precision.end());
  830. ap += (*max_prec)/11;
  831. }
  832. }
  833. }
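// Worked example of the two AP variants above (numbers are illustrative):
// suppose 3 images are ranked hit, miss, hit with 2 relevant images in total, giving
//   recall    = [0, 0.5, 0.5, 1.0]
//   precision = [1, 1.0, 0.5, 0.667]
// - VOC2010+ (m_sampled_ap == false): precision is first made monotonically
//   non-increasing -> [1, 1, 0.667, 0.667], and AP is the area under that curve, ~0.83
// - pre-VOC2010 (m_sampled_ap == true): precision is sampled at recall 0.0, 0.1, ..., 1.0
//   (taking the maximum precision at recall >= r) and the 11 samples are averaged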
  834. /* functions for calculating confusion matrix rows */
  835. //Calculate rows of a confusion matrix
  836. //------------------------------------
  837. //INPUTS:
  838. // - obj_class The VOC object class identifier string for the confusion matrix row to compute
  839. // - images An array of ObdImage containing the images to use for the computation
  840. // - scores A corresponding array of confidence scores for the presence of obj_class in each image
  841. // - cond Defines whether to use a cut off point based on recall (CV_VOC_CCOND_RECALL) or score
  842. // (CV_VOC_CCOND_SCORETHRESH) the latter is useful for classifier detections where positive
  843. // values are positive detections and negative values are negative detections
  844. // - threshold Threshold value for cond. In case of CV_VOC_CCOND_RECALL, is proportion recall (e.g. 0.5).
  845. // In the case of CV_VOC_CCOND_SCORETHRESH is the value above which to count results.
  846. //OUTPUTS:
  847. // - output_headers An output vector of object class headers for the confusion matrix row
  848. // - output_values An output vector of values for the confusion matrix row corresponding to the classes
  849. // defined in output_headers
  850. //NOTES:
  851. // The methodology used by the classifier version of this function is that true positives have a single unit
  852. // added to the obj_class column in the confusion matrix row, whereas false positives have a single unit
  853. // distributed in proportion between all the columns in the confusion matrix row corresponding to the objects
  854. // present in the image.
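// Illustrative example of the methodology above: if an image retrieved for 'car' does not
// actually contain a car but does contain one 'person' and one 'dog' object, 0.5 is added
// to the 'person' column and 0.5 to the 'dog' column of the 'car' row, rather than a full
// unit being added to the 'car' column.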
  855. void VocData::calcClassifierConfMatRow(const string& obj_class, const vector<ObdImage>& images, const vector<float>& scores, const VocConfCond cond, const float threshold, vector<string>& output_headers, vector<float>& output_values)
  856. {
  857. CV_Assert(images.size() == scores.size());
  858. // SORT RESULTS BY THEIR SCORE
  859. /* 1. store sorting order in 'ranking' */
  860. vector<size_t> ranking;
  861. VocData::getSortOrder(scores, ranking);
  862. // CALCULATE CONFUSION MATRIX ENTRIES
  863. /* prepare object category headers */
  864. output_headers = m_object_classes;
  865. vector<float>(output_headers.size(),0.0).swap(output_values);
  866. /* find the index of the target object class in the headers for later use */
  867. int target_idx;
  868. {
  869. vector<string>::iterator target_idx_it = std::find(output_headers.begin(),output_headers.end(),obj_class);
  870. /* if the target class can not be found, raise an exception */
  871. if (target_idx_it == output_headers.end())
  872. {
  873. string err_msg = "could not find the target object class '" + obj_class + "' in list of valid classes.";
  874. CV_Error(Error::StsError,err_msg.c_str());
  875. }
  876. /* convert iterator to index */
  877. target_idx = (int)std::distance(output_headers.begin(),target_idx_it);
  878. }
  879. /* prepare variables related to calculating recall if using the recall threshold */
  880. int retrieved_hits = 0;
  881. int total_relevant = 0;
  882. if (cond == CV_VOC_CCOND_RECALL)
  883. {
  884. vector<char> ground_truth;
  885. /* in order to calculate the total number of relevant images for normalization of recall
  886. it's necessary to extract the ground truth for the images under consideration */
  887. getClassifierGroundTruth(obj_class, images, ground_truth);
  888. #ifdef CV_CXX11
  889. total_relevant = (int)std::count_if(ground_truth.begin(),ground_truth.end(),
  890. [](const char a) { return a == (char)1; });
  891. #else
  892. total_relevant = (int)std::count_if(ground_truth.begin(),ground_truth.end(),std::bind2nd(std::equal_to<char>(),(char)1));
  893. #endif
  894. }
  895. /* iterate through images */
  896. vector<ObdObject> img_objects;
  897. vector<VocObjectData> img_object_data;
  898. int total_images = 0;
  899. for (size_t image_idx = 0; image_idx < images.size(); ++image_idx)
  900. {
  901. /* if using the score as the break condition, check for it now */
  902. if (cond == CV_VOC_CCOND_SCORETHRESH)
  903. {
  904. if (scores[ranking[image_idx]] <= threshold) break;
  905. }
  906. /* if continuing for this iteration, increment the image counter for later normalization */
  907. ++total_images;
  908. /* for each image retrieve the objects contained */
  909. getObjects(images[ranking[image_idx]].id, img_objects, img_object_data);
  910. //check if the tested for object class is present
  911. if (getClassifierGroundTruthImage(obj_class, images[ranking[image_idx]].id))
  912. {
  913. //if the target class is present, assign fully to the target class element in the confusion matrix row
  914. output_values[target_idx] += 1.0;
  915. if (cond == CV_VOC_CCOND_RECALL) ++retrieved_hits;
  916. } else {
  917. //first delete all objects marked as difficult
  918. for (size_t obj_idx = 0; obj_idx < img_objects.size(); ++obj_idx)
  919. {
  920. if (img_object_data[obj_idx].difficult == true)
  921. {
  922. vector<ObdObject>::iterator it1 = img_objects.begin();
  923. std::advance(it1,obj_idx);
  924. img_objects.erase(it1);
  925. vector<VocObjectData>::iterator it2 = img_object_data.begin();
  926. std::advance(it2,obj_idx);
  927. img_object_data.erase(it2);
  928. --obj_idx;
  929. }
  930. }
  931. //if the target class is not present, add values to the confusion matrix row in equal proportions to all objects present in the image
  932. for (size_t obj_idx = 0; obj_idx < img_objects.size(); ++obj_idx)
  933. {
  934. //find the index of the currently considered object
  935. vector<string>::iterator class_idx_it = std::find(output_headers.begin(),output_headers.end(),img_objects[obj_idx].object_class);
  936. //if the class name extracted from the ground truth file could not be found in the list of available classes, raise an exception
  937. if (class_idx_it == output_headers.end())
  938. {
string err_msg = "could not find object class '" + img_objects[obj_idx].object_class + "' specified in the ground truth file of '" + images[ranking[image_idx]].id + "' in list of valid classes.";
  940. CV_Error(Error::StsError,err_msg.c_str());
  941. }
  942. /* convert iterator to index */
  943. int class_idx = (int)std::distance(output_headers.begin(),class_idx_it);
  944. //add to confusion matrix row in proportion
  945. output_values[class_idx] += 1.f/static_cast<float>(img_objects.size());
  946. }
  947. }
  948. //check break conditions if breaking on certain level of recall
  949. if (cond == CV_VOC_CCOND_RECALL)
  950. {
  951. if(static_cast<float>(retrieved_hits)/static_cast<float>(total_relevant) >= threshold) break;
  952. }
  953. }
  954. /* finally, normalize confusion matrix row */
  955. for (vector<float>::iterator it = output_values.begin(); it < output_values.end(); ++it)
  956. {
  957. (*it) /= static_cast<float>(total_images);
  958. }
  959. }
  960. // NOTE: doesn't ignore repeated detections
  961. void VocData::calcDetectorConfMatRow(const string& obj_class, const ObdDatasetType dataset, const vector<ObdImage>& images, const vector<vector<float> >& scores, const vector<vector<Rect> >& bounding_boxes, const VocConfCond cond, const float threshold, vector<string>& output_headers, vector<float>& output_values, bool ignore_difficult)
  962. {
  963. CV_Assert(images.size() == scores.size());
  964. CV_Assert(images.size() == bounding_boxes.size());
  965. //collapse scores and ground_truth vectors into 1D vectors to allow ranking
  966. /* define final flat vectors */
  967. vector<string> images_flat;
  968. vector<float> scores_flat;
  969. vector<Rect> bounding_boxes_flat;
  970. {
  971. /* first count how many objects to allow preallocation */
  972. int obj_count = 0;
  973. CV_Assert(scores.size() == bounding_boxes.size());
  974. for (size_t img_idx = 0; img_idx < scores.size(); ++img_idx)
  975. {
  976. CV_Assert(scores[img_idx].size() == bounding_boxes[img_idx].size());
  977. for (size_t obj_idx = 0; obj_idx < scores[img_idx].size(); ++obj_idx)
  978. {
  979. ++obj_count;
  980. }
  981. }
  982. /* preallocate vectors */
  983. images_flat.resize(obj_count);
  984. scores_flat.resize(obj_count);
  985. bounding_boxes_flat.resize(obj_count);
  986. /* now copy across to preallocated vectors */
  987. int flat_pos = 0;
  988. for (size_t img_idx = 0; img_idx < scores.size(); ++img_idx)
  989. {
  990. for (size_t obj_idx = 0; obj_idx < scores[img_idx].size(); ++obj_idx)
  991. {
  992. images_flat[flat_pos] = images[img_idx].id;
  993. scores_flat[flat_pos] = scores[img_idx][obj_idx];
  994. bounding_boxes_flat[flat_pos] = bounding_boxes[img_idx][obj_idx];
  995. ++flat_pos;
  996. }
  997. }
  998. }
  999. // SORT RESULTS BY THEIR SCORE
  1000. /* 1. store sorting order in 'ranking' */
  1001. vector<size_t> ranking;
  1002. VocData::getSortOrder(scores_flat, ranking);
  1003. // CALCULATE CONFUSION MATRIX ENTRIES
  1004. /* prepare object category headers */
  1005. output_headers = m_object_classes;
  1006. output_headers.push_back("background");
  1007. vector<float>(output_headers.size(),0.0).swap(output_values);
  1008. /* prepare variables related to calculating recall if using the recall threshold */
  1009. int retrieved_hits = 0;
  1010. int total_relevant = 0;
  1011. if (cond == CV_VOC_CCOND_RECALL)
  1012. {
  1013. // vector<char> ground_truth;
  1014. // /* in order to calculate the total number of relevant images for normalization of recall
  1015. // it's necessary to extract the ground truth for the images under consideration */
  1016. // getClassifierGroundTruth(obj_class, images, ground_truth);
  1017. // total_relevant = std::count_if(ground_truth.begin(),ground_truth.end(),std::bind2nd(std::equal_to<bool>(),true));
  1018. /* calculate the total number of objects in the ground truth for the current dataset */
  1019. vector<ObdImage> gt_images;
  1020. vector<char> gt_object_present;
  1021. getClassImages(obj_class, dataset, gt_images, gt_object_present);
  1022. for (size_t image_idx = 0; image_idx < gt_images.size(); ++image_idx)
  1023. {
  1024. vector<ObdObject> gt_img_objects;
  1025. vector<VocObjectData> gt_img_object_data;
  1026. getObjects(gt_images[image_idx].id, gt_img_objects, gt_img_object_data);
  1027. for (size_t obj_idx = 0; obj_idx < gt_img_objects.size(); ++obj_idx)
  1028. {
  1029. if (gt_img_objects[obj_idx].object_class == obj_class)
  1030. {
  1031. if ((gt_img_object_data[obj_idx].difficult == false) || (ignore_difficult == false))
  1032. ++total_relevant;
  1033. }
  1034. }
  1035. }
  1036. }
  1037. /* iterate through objects */
  1038. vector<ObdObject> img_objects;
  1039. vector<VocObjectData> img_object_data;
  1040. int total_objects = 0;
  1041. for (size_t image_idx = 0; image_idx < images.size(); ++image_idx)
  1042. {
  1043. /* if using the score as the break condition, check for it now */
  1044. if (cond == CV_VOC_CCOND_SCORETHRESH)
  1045. {
  1046. if (scores_flat[ranking[image_idx]] <= threshold) break;
  1047. }
/* increment the object counter for later normalization */
  1049. ++total_objects;
  1050. /* for each image retrieve the objects contained */
  1051. getObjects(images[ranking[image_idx]].id, img_objects, img_object_data);
  1052. //find the ground truth object which has the highest overlap score with the detected object
  1053. float maxov = -1.0;
  1054. int max_gt_obj_idx = -1;
  1055. //-- for each detected object iterate through objects present in ground truth --
  1056. for (size_t gt_obj_idx = 0; gt_obj_idx < img_objects.size(); ++gt_obj_idx)
  1057. {
  1058. //check difficulty flag
  1059. if (ignore_difficult || (img_object_data[gt_obj_idx].difficult == false))
  1060. {
//check whether the detected object and this ground truth object overlap by a sufficient margin
  1062. float ov = testBoundingBoxesForOverlap(bounding_boxes_flat[ranking[image_idx]], img_objects[gt_obj_idx].boundingBox);
  1063. if (ov != -1.f)
  1064. {
  1065. //if all conditions are met store the overlap score and index (as objects are assigned to the highest scoring match)
  1066. if (ov > maxov)
  1067. {
  1068. maxov = ov;
  1069. max_gt_obj_idx = (int)gt_obj_idx;
  1070. }
  1071. }
  1072. }
  1073. }
  1074. //assign to appropriate object class if an object was detected
  1075. if (maxov != -1.0)
  1076. {
  1077. //find the index of the currently considered object
  1078. vector<string>::iterator class_idx_it = std::find(output_headers.begin(),output_headers.end(),img_objects[max_gt_obj_idx].object_class);
  1079. //if the class name extracted from the ground truth file could not be found in the list of available classes, raise an exception
  1080. if (class_idx_it == output_headers.end())
  1081. {
string err_msg = "could not find object class '" + img_objects[max_gt_obj_idx].object_class + "' specified in the ground truth file of '" + images[ranking[image_idx]].id + "' in list of valid classes.";
  1083. CV_Error(Error::StsError,err_msg.c_str());
  1084. }
  1085. /* convert iterator to index */
  1086. int class_idx = (int)std::distance(output_headers.begin(),class_idx_it);
//add the detection to the confusion matrix row entry of the matched ground truth class
output_values[class_idx] += 1.0;
//count the detection as a retrieved hit if it was assigned to the target class (used by the recall break condition below)
if ((cond == CV_VOC_CCOND_RECALL) && (img_objects[max_gt_obj_idx].object_class == obj_class)) ++retrieved_hits;
  1089. } else {
  1090. //otherwise assign to background class
  1091. output_values[output_values.size()-1] += 1.0;
  1092. }
  1093. //check break conditions if breaking on certain level of recall
  1094. if (cond == CV_VOC_CCOND_RECALL)
  1095. {
  1096. if(static_cast<float>(retrieved_hits)/static_cast<float>(total_relevant) >= threshold) break;
  1097. }
  1098. }
  1099. /* finally, normalize confusion matrix row */
  1100. for (vector<float>::iterator it = output_values.begin(); it < output_values.end(); ++it)
  1101. {
  1102. (*it) /= static_cast<float>(total_objects);
  1103. }
  1104. }
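// Example (an illustrative sketch): given per-image detection 'scores' and 'bounding_boxes' for the "car"
// class on the training dataset, a detector confusion matrix row counting all detections with positive
// scores could be computed as:
//
//     vector<string> headers;
//     vector<float> values;
//     vocData.calcDetectorConfMatRow("car", CV_OBD_TRAIN, images, scores, bounding_boxes,
//                                    CV_VOC_CCOND_SCORETHRESH, 0.f, headers, values, true);
//     // the final "background" entry of 'values' holds the proportion of unmatched (false positive) detections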
  1105. //Save Precision-Recall results to a p-r curve in GNUPlot format
  1106. //--------------------------------------------------------------
  1107. //INPUTS:
  1108. // - output_file The file to which to save the GNUPlot data file. If only a filename is specified, the data
  1109. // file is saved to the standard VOC results directory.
  1110. // - precision Vector of precisions as returned from calcClassifier/DetectorPrecRecall
  1111. // - recall Vector of recalls as returned from calcClassifier/DetectorPrecRecall
  1112. // - ap ap as returned from calcClassifier/DetectorPrecRecall
// - (title) Title to use for the plot (if not specified, a default title with the ap appended is used)
// This also specifies the filename of the output image if saving to a PNG file
// - (plot_type) Specifies whether to instruct GNUPlot to save to a PNG file (CV_VOC_PLOT_PNG) or to plot directly
// to screen (CV_VOC_PLOT_SCREEN) in the datafile
//NOTES:
// The GNUPlot data file can be executed using gnuplot from the command line in the following way:
// >> gnuplot <output_file>
// This will then display the p-r curve on the screen or save it to a PNG file depending on plot_type
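//
// Example (an illustrative sketch): the precision/recall/ap values returned by the evaluation functions can
// be written out and plotted as:
//
//     vocData.savePrecRecallToGnuplot("pr_cat.plt", precision, recall, ap, "cat_pr", CV_VOC_PLOT_PNG);
//     // then, from the command line:   gnuplot pr_cat.plt   (produces cat_pr.png)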
  1121. void VocData::savePrecRecallToGnuplot(const string& output_file, const vector<float>& precision, const vector<float>& recall, const float ap, const string title, const VocPlotType plot_type)
  1122. {
  1123. string output_file_std = checkFilenamePathsep(output_file);
  1124. //if no directory is specified, by default save the output file in the results directory
  1125. // if (output_file_std.find("/") == output_file_std.npos)
  1126. // {
  1127. // output_file_std = m_results_directory + output_file_std;
  1128. // }
  1129. std::ofstream plot_file(output_file_std.c_str());
  1130. if (plot_file.is_open())
  1131. {
  1132. plot_file << "set xrange [0:1]" << endl;
  1133. plot_file << "set yrange [0:1]" << endl;
  1134. plot_file << "set size square" << endl;
  1135. string title_text = title;
  1136. if (title_text.size() == 0) title_text = "Precision-Recall Curve";
  1137. plot_file << "set title \"" << title_text << " (ap: " << ap << ")\"" << endl;
  1138. plot_file << "set xlabel \"Recall\"" << endl;
  1139. plot_file << "set ylabel \"Precision\"" << endl;
  1140. plot_file << "set style data lines" << endl;
  1141. plot_file << "set nokey" << endl;
  1142. if (plot_type == CV_VOC_PLOT_PNG)
  1143. {
  1144. plot_file << "set terminal png" << endl;
string png_filename;
if (title.size() != 0)
{
png_filename = title;
} else {
png_filename = "prcurve";
}
plot_file << "set out \"" << png_filename << ".png\"" << endl;
  1153. }
  1154. plot_file << "plot \"-\" using 1:2" << endl;
  1155. plot_file << "# X Y" << endl;
  1156. CV_Assert(precision.size() == recall.size());
  1157. for (size_t i = 0; i < precision.size(); ++i)
  1158. {
  1159. plot_file << " " << recall[i] << " " << precision[i] << endl;
  1160. }
  1161. plot_file << "end" << endl;
  1162. if (plot_type == CV_VOC_PLOT_SCREEN)
  1163. {
  1164. plot_file << "pause -1" << endl;
  1165. }
  1166. plot_file.close();
  1167. } else {
  1168. string err_msg = "could not open plot file '" + output_file_std + "' for writing.";
  1169. CV_Error(Error::StsError,err_msg.c_str());
  1170. }
  1171. }
  1172. void VocData::readClassifierGroundTruth(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<char>& object_present)
  1173. {
  1174. images.clear();
  1175. string gtFilename = m_class_imageset_path;
  1176. gtFilename.replace(gtFilename.find("%s"),2,obj_class);
  1177. if (dataset == CV_OBD_TRAIN)
  1178. {
  1179. gtFilename.replace(gtFilename.find("%s"),2,m_train_set);
  1180. } else {
  1181. gtFilename.replace(gtFilename.find("%s"),2,m_test_set);
  1182. }
  1183. vector<string> image_codes;
  1184. readClassifierGroundTruth(gtFilename, image_codes, object_present);
  1185. convertImageCodesToObdImages(image_codes, images);
  1186. }
  1187. void VocData::readClassifierResultsFile(const std:: string& input_file, vector<ObdImage>& images, vector<float>& scores)
  1188. {
  1189. images.clear();
  1190. string input_file_std = checkFilenamePathsep(input_file);
  1191. //if no directory is specified, by default search for the input file in the results directory
  1192. // if (input_file_std.find("/") == input_file_std.npos)
  1193. // {
  1194. // input_file_std = m_results_directory + input_file_std;
  1195. // }
  1196. vector<string> image_codes;
  1197. readClassifierResultsFile(input_file_std, image_codes, scores);
  1198. convertImageCodesToObdImages(image_codes, images);
  1199. }
  1200. void VocData::readDetectorResultsFile(const string& input_file, vector<ObdImage>& images, vector<vector<float> >& scores, vector<vector<Rect> >& bounding_boxes)
  1201. {
  1202. images.clear();
  1203. string input_file_std = checkFilenamePathsep(input_file);
  1204. //if no directory is specified, by default search for the input file in the results directory
  1205. // if (input_file_std.find("/") == input_file_std.npos)
  1206. // {
  1207. // input_file_std = m_results_directory + input_file_std;
  1208. // }
  1209. vector<string> image_codes;
  1210. readDetectorResultsFile(input_file_std, image_codes, scores, bounding_boxes);
  1211. convertImageCodesToObdImages(image_codes, images);
  1212. }
  1213. const vector<string>& VocData::getObjectClasses()
  1214. {
  1215. return m_object_classes;
  1216. }
  1217. //string VocData::getResultsDirectory()
  1218. //{
  1219. // return m_results_directory;
  1220. //}
  1221. //---------------------------------------------------------
  1222. // Protected Functions ------------------------------------
  1223. //---------------------------------------------------------
  1224. static string getVocName( const string& vocPath )
  1225. {
  1226. size_t found = vocPath.rfind( '/' );
  1227. if( found == string::npos )
  1228. {
  1229. found = vocPath.rfind( '\\' );
  1230. if( found == string::npos )
  1231. return vocPath;
  1232. }
  1233. return vocPath.substr(found + 1, vocPath.size() - found);
  1234. }
  1235. void VocData::initVoc( const string& vocPath, const bool useTestDataset )
  1236. {
  1237. initVoc2007to2010( vocPath, useTestDataset );
  1238. }
  1239. //Initialize file paths and settings for the VOC 2010 dataset
  1240. //-----------------------------------------------------------
  1241. void VocData::initVoc2007to2010( const string& vocPath, const bool useTestDataset )
  1242. {
  1243. //check format of root directory and modify if necessary
  1244. m_vocName = getVocName( vocPath );
  1245. CV_Assert( !m_vocName.compare("VOC2007") || !m_vocName.compare("VOC2008") ||
  1246. !m_vocName.compare("VOC2009") || !m_vocName.compare("VOC2010") );
  1247. m_vocPath = checkFilenamePathsep( vocPath, true );
  1248. if (useTestDataset)
  1249. {
  1250. m_train_set = "trainval";
  1251. m_test_set = "test";
  1252. } else {
  1253. m_train_set = "train";
  1254. m_test_set = "val";
  1255. }
  1256. // initialize main classification/detection challenge paths
  1257. m_annotation_path = m_vocPath + "/Annotations/%s.xml";
  1258. m_image_path = m_vocPath + "/JPEGImages/%s.jpg";
  1259. m_imageset_path = m_vocPath + "/ImageSets/Main/%s.txt";
  1260. m_class_imageset_path = m_vocPath + "/ImageSets/Main/%s_%s.txt";
  1261. //define available object_classes for VOC2010 dataset
  1262. m_object_classes.push_back("aeroplane");
  1263. m_object_classes.push_back("bicycle");
  1264. m_object_classes.push_back("bird");
  1265. m_object_classes.push_back("boat");
  1266. m_object_classes.push_back("bottle");
  1267. m_object_classes.push_back("bus");
  1268. m_object_classes.push_back("car");
  1269. m_object_classes.push_back("cat");
  1270. m_object_classes.push_back("chair");
  1271. m_object_classes.push_back("cow");
  1272. m_object_classes.push_back("diningtable");
  1273. m_object_classes.push_back("dog");
  1274. m_object_classes.push_back("horse");
  1275. m_object_classes.push_back("motorbike");
  1276. m_object_classes.push_back("person");
  1277. m_object_classes.push_back("pottedplant");
  1278. m_object_classes.push_back("sheep");
  1279. m_object_classes.push_back("sofa");
  1280. m_object_classes.push_back("train");
  1281. m_object_classes.push_back("tvmonitor");
  1282. m_min_overlap = 0.5;
  1283. //up until VOC 2010, ap was calculated by sampling p-r curve, not taking complete curve
  1284. m_sampled_ap = ((m_vocName == "VOC2007") || (m_vocName == "VOC2008") || (m_vocName == "VOC2009"));
  1285. }
  1286. //Read a VOC classification ground truth text file for a given object class and dataset
  1287. //-------------------------------------------------------------------------------------
  1288. //INPUTS:
  1289. // - filename The path of the text file to read
  1290. //OUTPUTS:
  1291. // - image_codes VOC image codes extracted from the GT file in the form 20XX_XXXXXX where the first four
  1292. // digits specify the year of the dataset, and the last group specifies a unique ID
  1293. // - object_present For each image in the 'image_codes' array, specifies whether the object class described
  1294. // in the loaded GT file is present or not
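//
// Example of the expected file format (illustrative values): each line holds an image code followed by an
// integer label, where 1 marks the class as present and any other value (typically -1, or 0 for difficult
// examples) is treated as absent, e.g.
//
//     2009_000010  1
//     2009_000012 -1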
  1295. void VocData::readClassifierGroundTruth(const string& filename, vector<string>& image_codes, vector<char>& object_present)
  1296. {
  1297. image_codes.clear();
  1298. object_present.clear();
  1299. std::ifstream gtfile(filename.c_str());
  1300. if (!gtfile.is_open())
  1301. {
  1302. string err_msg = "could not open VOC ground truth textfile '" + filename + "'.";
  1303. CV_Error(Error::StsError,err_msg.c_str());
  1304. }
  1305. string line;
  1306. string image;
  1307. int obj_present = 0;
  1308. while (!gtfile.eof())
  1309. {
  1310. std::getline(gtfile,line);
  1311. std::istringstream iss(line);
  1312. iss >> image >> obj_present;
  1313. if (!iss.fail())
  1314. {
  1315. image_codes.push_back(image);
  1316. object_present.push_back(obj_present == 1);
  1317. } else {
  1318. if (!gtfile.eof()) CV_Error(Error::StsParseError,"error parsing VOC ground truth textfile.");
  1319. }
  1320. }
  1321. gtfile.close();
  1322. }
  1323. void VocData::readClassifierResultsFile(const string& input_file, vector<string>& image_codes, vector<float>& scores)
  1324. {
  1325. //check if results file exists
  1326. std::ifstream result_file(input_file.c_str());
  1327. if (result_file.is_open())
  1328. {
  1329. string line;
  1330. string image;
  1331. float score;
  1332. //read in the results file
  1333. while (!result_file.eof())
  1334. {
  1335. std::getline(result_file,line);
  1336. std::istringstream iss(line);
  1337. iss >> image >> score;
  1338. if (!iss.fail())
  1339. {
  1340. image_codes.push_back(image);
  1341. scores.push_back(score);
  1342. } else {
  1343. if(!result_file.eof()) CV_Error(Error::StsParseError,"error parsing VOC classifier results file.");
  1344. }
  1345. }
  1346. result_file.close();
  1347. } else {
  1348. string err_msg = "could not open classifier results file '" + input_file + "' for reading.";
  1349. CV_Error(Error::StsError,err_msg.c_str());
  1350. }
  1351. }
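// Each line of a VOC standard detector results file is expected to hold:
//     <image code> <confidence> <xmin> <ymin> <xmax> <ymax>
// with 1-based pixel coordinates, for example (illustrative values):
//     2009_000026 0.871 144 21 288 232
// The coordinates are converted to a 0-based x/y origin plus width/height while loading.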
  1352. void VocData::readDetectorResultsFile(const string& input_file, vector<string>& image_codes, vector<vector<float> >& scores, vector<vector<Rect> >& bounding_boxes)
  1353. {
  1354. image_codes.clear();
  1355. scores.clear();
  1356. bounding_boxes.clear();
  1357. //check if results file exists
  1358. std::ifstream result_file(input_file.c_str());
  1359. if (result_file.is_open())
  1360. {
  1361. string line;
  1362. string image;
  1363. Rect bounding_box;
  1364. float score;
  1365. //read in the results file
  1366. while (!result_file.eof())
  1367. {
  1368. std::getline(result_file,line);
  1369. std::istringstream iss(line);
  1370. iss >> image >> score >> bounding_box.x >> bounding_box.y >> bounding_box.width >> bounding_box.height;
  1371. if (!iss.fail())
  1372. {
  1373. //convert right and bottom positions to width and height
  1374. bounding_box.width -= bounding_box.x;
  1375. bounding_box.height -= bounding_box.y;
  1376. //convert to 0-indexing
  1377. bounding_box.x -= 1;
  1378. bounding_box.y -= 1;
  1379. //store in output vectors
  1380. /* first check if the current image code has been seen before */
  1381. vector<string>::iterator image_codes_it = std::find(image_codes.begin(),image_codes.end(),image);
  1382. if (image_codes_it == image_codes.end())
  1383. {
  1384. image_codes.push_back(image);
  1385. vector<float> score_vect(1);
  1386. score_vect[0] = score;
  1387. scores.push_back(score_vect);
  1388. vector<Rect> bounding_box_vect(1);
  1389. bounding_box_vect[0] = bounding_box;
  1390. bounding_boxes.push_back(bounding_box_vect);
  1391. } else {
  1392. /* if the image index has been seen before, add the current object below it in the 2D arrays */
  1393. int image_idx = (int)std::distance(image_codes.begin(),image_codes_it);
  1394. scores[image_idx].push_back(score);
  1395. bounding_boxes[image_idx].push_back(bounding_box);
  1396. }
  1397. } else {
  1398. if(!result_file.eof()) CV_Error(Error::StsParseError,"error parsing VOC detector results file.");
  1399. }
  1400. }
  1401. result_file.close();
  1402. } else {
  1403. string err_msg = "could not open detector results file '" + input_file + "' for reading.";
  1404. CV_Error(Error::StsError,err_msg.c_str());
  1405. }
  1406. }
  1407. //Read a VOC annotation xml file for a given image
  1408. //------------------------------------------------
  1409. //INPUTS:
  1410. // - filename The path of the xml file to read
  1411. //OUTPUTS:
  1412. // - objects Array of VocObject describing all object instances present in the given image
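//
// Each <object> block of the annotation file is expected to look roughly as follows (illustrative values;
// the parser requires the <name>, <xmin>, <ymin>, <xmax> and <ymax> tags, while the remaining tags are optional):
//
//     <object>
//       <name>cat</name>
//       <pose>Left</pose>
//       <truncated>0</truncated>
//       <difficult>0</difficult>
//       <bndbox><xmin>48</xmin><ymin>32</ymin><xmax>195</xmax><ymax>180</ymax></bndbox>
//     </object>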
  1413. void VocData::extractVocObjects(const string filename, vector<ObdObject>& objects, vector<VocObjectData>& object_data)
  1414. {
  1415. #ifdef PR_DEBUG
  1416. int block = 1;
  1417. cout << "SAMPLE VOC OBJECT EXTRACTION for " << filename << ":" << endl;
  1418. #endif
  1419. objects.clear();
  1420. object_data.clear();
  1421. string contents, object_contents, tag_contents;
  1422. readFileToString(filename, contents);
  1423. //keep on extracting 'object' blocks until no more can be found
  1424. if (extractXMLBlock(contents, "annotation", 0, contents) != -1)
  1425. {
  1426. int searchpos = 0;
  1427. searchpos = extractXMLBlock(contents, "object", searchpos, object_contents);
  1428. while (searchpos != -1)
  1429. {
  1430. #ifdef PR_DEBUG
  1431. cout << "SEARCHPOS:" << searchpos << endl;
  1432. cout << "start block " << block << " ---------" << endl;
  1433. cout << object_contents << endl;
  1434. cout << "end block " << block << " -----------" << endl;
  1435. ++block;
  1436. #endif
  1437. ObdObject object;
  1438. VocObjectData object_d;
  1439. //object class -------------
  1440. if (extractXMLBlock(object_contents, "name", 0, tag_contents) == -1) CV_Error(Error::StsError,"missing <name> tag in object definition of '" + filename + "'");
  1441. object.object_class.swap(tag_contents);
  1442. //object bounding box -------------
  1443. int xmax, xmin, ymax, ymin;
  1444. if (extractXMLBlock(object_contents, "xmax", 0, tag_contents) == -1) CV_Error(Error::StsError,"missing <xmax> tag in object definition of '" + filename + "'");
  1445. xmax = stringToInteger(tag_contents);
  1446. if (extractXMLBlock(object_contents, "xmin", 0, tag_contents) == -1) CV_Error(Error::StsError,"missing <xmin> tag in object definition of '" + filename + "'");
  1447. xmin = stringToInteger(tag_contents);
  1448. if (extractXMLBlock(object_contents, "ymax", 0, tag_contents) == -1) CV_Error(Error::StsError,"missing <ymax> tag in object definition of '" + filename + "'");
  1449. ymax = stringToInteger(tag_contents);
  1450. if (extractXMLBlock(object_contents, "ymin", 0, tag_contents) == -1) CV_Error(Error::StsError,"missing <ymin> tag in object definition of '" + filename + "'");
  1451. ymin = stringToInteger(tag_contents);
  1452. object.boundingBox.x = xmin-1; //convert to 0-based indexing
  1453. object.boundingBox.width = xmax - xmin;
  1454. object.boundingBox.y = ymin-1;
  1455. object.boundingBox.height = ymax - ymin;
  1456. CV_Assert(xmin != 0);
  1457. CV_Assert(xmax > xmin);
  1458. CV_Assert(ymin != 0);
  1459. CV_Assert(ymax > ymin);
  1460. //object tags -------------
  1461. if (extractXMLBlock(object_contents, "difficult", 0, tag_contents) != -1)
  1462. {
  1463. object_d.difficult = (tag_contents == "1");
  1464. } else object_d.difficult = false;
  1465. if (extractXMLBlock(object_contents, "occluded", 0, tag_contents) != -1)
  1466. {
  1467. object_d.occluded = (tag_contents == "1");
  1468. } else object_d.occluded = false;
  1469. if (extractXMLBlock(object_contents, "truncated", 0, tag_contents) != -1)
  1470. {
  1471. object_d.truncated = (tag_contents == "1");
  1472. } else object_d.truncated = false;
  1473. if (extractXMLBlock(object_contents, "pose", 0, tag_contents) != -1)
  1474. {
  1475. if (tag_contents == "Frontal") object_d.pose = CV_VOC_POSE_FRONTAL;
  1476. if (tag_contents == "Rear") object_d.pose = CV_VOC_POSE_REAR;
  1477. if (tag_contents == "Left") object_d.pose = CV_VOC_POSE_LEFT;
  1478. if (tag_contents == "Right") object_d.pose = CV_VOC_POSE_RIGHT;
  1479. }
  1480. //add to array of objects
  1481. objects.push_back(object);
  1482. object_data.push_back(object_d);
  1483. //extract next 'object' block from file if it exists
  1484. searchpos = extractXMLBlock(contents, "object", searchpos, object_contents);
  1485. }
  1486. }
  1487. }
//Returns the full path of the JPEG image corresponding to a VOC image identifier string
//of the format YYYY_XXXXXX, by substituting the identifier into the image path template
  1490. //----------------------------------------------------------------------------------------------------------
  1491. string VocData::getImagePath(const string& input_str)
  1492. {
  1493. string path = m_image_path;
  1494. path.replace(path.find("%s"),2,input_str);
  1495. return path;
  1496. }
//Tests two bounding boxes for overlap (using the intersection over union metric) and returns the overlap if the objects
  1498. //defined by the two bounding boxes are considered to be matched according to the criterion outlined in
  1499. //the VOC documentation [namely intersection/union > some threshold] otherwise returns -1.0 (no match)
  1500. //----------------------------------------------------------------------------------------------------------
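// Worked example (illustrative): for detection Rect(0,0,9,9) and ground truth Rect(5,5,9,9), each box covers
// (9+1)*(9+1) = 100 pixels under the inclusive pixel convention used below, the intersection covers 5*5 = 25
// pixels, so overlap = 25 / (100 + 100 - 25) ~= 0.14 < 0.5 and the function returns -1.0 (no match).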
  1501. float VocData::testBoundingBoxesForOverlap(const Rect detection, const Rect ground_truth)
  1502. {
  1503. int detection_x2 = detection.x + detection.width;
  1504. int detection_y2 = detection.y + detection.height;
  1505. int ground_truth_x2 = ground_truth.x + ground_truth.width;
  1506. int ground_truth_y2 = ground_truth.y + ground_truth.height;
  1507. //first calculate the boundaries of the intersection of the rectangles
  1508. int intersection_x = std::max(detection.x, ground_truth.x); //rightmost left
  1509. int intersection_y = std::max(detection.y, ground_truth.y); //bottommost top
  1510. int intersection_x2 = std::min(detection_x2, ground_truth_x2); //leftmost right
  1511. int intersection_y2 = std::min(detection_y2, ground_truth_y2); //topmost bottom
  1512. //then calculate the width and height of the intersection rect
  1513. int intersection_width = intersection_x2 - intersection_x + 1;
  1514. int intersection_height = intersection_y2 - intersection_y + 1;
  1515. //if there is no overlap then return false straight away
  1516. if ((intersection_width <= 0) || (intersection_height <= 0)) return -1.0;
  1517. //otherwise calculate the intersection
  1518. int intersection_area = intersection_width*intersection_height;
  1519. //now calculate the union
  1520. int union_area = (detection.width+1)*(detection.height+1) + (ground_truth.width+1)*(ground_truth.height+1) - intersection_area;
  1521. //calculate the intersection over union and use as threshold as per VOC documentation
  1522. float overlap = static_cast<float>(intersection_area)/static_cast<float>(union_area);
  1523. if (overlap > m_min_overlap)
  1524. {
  1525. return overlap;
  1526. } else {
  1527. return -1.0;
  1528. }
  1529. }
  1530. //Extracts the object class and dataset from the filename of a VOC standard results text file, which takes
  1531. //the format 'comp<n>_{cls/det}_<dataset>_<objclass>.txt'
  1532. //----------------------------------------------------------------------------------------------------------
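// For example (illustrative filename), "comp1_cls_val_cat.txt" yields dataset_name = "val" and class_name = "cat".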
  1533. void VocData::extractDataFromResultsFilename(const string& input_file, string& class_name, string& dataset_name)
  1534. {
  1535. string input_file_std = checkFilenamePathsep(input_file);
  1536. size_t fnamestart = input_file_std.rfind("/");
  1537. size_t fnameend = input_file_std.rfind(".txt");
  1538. if ((fnamestart == input_file_std.npos) || (fnameend == input_file_std.npos))
  1539. CV_Error(Error::StsError,"Could not extract filename of results file.");
  1540. ++fnamestart;
  1541. if (fnamestart >= fnameend)
  1542. CV_Error(Error::StsError,"Could not extract filename of results file.");
  1543. //extract dataset and class names, triggering exception if the filename format is not correct
  1544. string filename = input_file_std.substr(fnamestart, fnameend-fnamestart);
  1545. size_t datasetstart = filename.find("_");
  1546. datasetstart = filename.find("_",datasetstart+1);
  1547. size_t classstart = filename.find("_",datasetstart+1);
  1548. //allow for appended index after a further '_' by discarding this part if it exists
  1549. size_t classend = filename.find("_",classstart+1);
  1550. if (classend == filename.npos) classend = filename.size();
  1551. if ((datasetstart == filename.npos) || (classstart == filename.npos))
  1552. CV_Error(Error::StsError,"Error parsing results filename. Is it in standard format of 'comp<n>_{cls/det}_<dataset>_<objclass>.txt'?");
  1553. ++datasetstart;
  1554. ++classstart;
if (((classstart-datasetstart) < 2) || ((classend-classstart) < 1))
  1556. CV_Error(Error::StsError,"Error parsing results filename. Is it in standard format of 'comp<n>_{cls/det}_<dataset>_<objclass>.txt'?");
  1557. dataset_name = filename.substr(datasetstart,classstart-datasetstart-1);
  1558. class_name = filename.substr(classstart,classend-classstart);
  1559. }
  1560. bool VocData::getClassifierGroundTruthImage(const string& obj_class, const string& id)
  1561. {
  1562. /* if the classifier ground truth data for all images of the current class has not been loaded yet, load it now */
  1563. if (m_classifier_gt_all_ids.empty() || (m_classifier_gt_class != obj_class))
  1564. {
  1565. m_classifier_gt_all_ids.clear();
  1566. m_classifier_gt_all_present.clear();
  1567. m_classifier_gt_class = obj_class;
  1568. for (int i=0; i<2; ++i) //run twice (once over test set and once over training set)
  1569. {
  1570. //generate the filename of the classification ground-truth textfile for the object class
  1571. string gtFilename = m_class_imageset_path;
  1572. gtFilename.replace(gtFilename.find("%s"),2,obj_class);
  1573. if (i == 0)
  1574. {
  1575. gtFilename.replace(gtFilename.find("%s"),2,m_train_set);
  1576. } else {
  1577. gtFilename.replace(gtFilename.find("%s"),2,m_test_set);
  1578. }
  1579. //parse the ground truth file, storing in two separate vectors
  1580. //for the image code and the ground truth value
  1581. vector<string> image_codes;
  1582. vector<char> object_present;
  1583. readClassifierGroundTruth(gtFilename, image_codes, object_present);
  1584. m_classifier_gt_all_ids.insert(m_classifier_gt_all_ids.end(),image_codes.begin(),image_codes.end());
  1585. m_classifier_gt_all_present.insert(m_classifier_gt_all_present.end(),object_present.begin(),object_present.end());
  1586. CV_Assert(m_classifier_gt_all_ids.size() == m_classifier_gt_all_present.size());
  1587. }
  1588. }
  1589. //search for the image code
  1590. vector<string>::iterator it = find (m_classifier_gt_all_ids.begin(), m_classifier_gt_all_ids.end(), id);
  1591. if (it != m_classifier_gt_all_ids.end())
  1592. {
  1593. //image found, so return corresponding ground truth
  1594. return m_classifier_gt_all_present[std::distance(m_classifier_gt_all_ids.begin(),it)] != 0;
  1595. }
  1596. string err_msg = "could not find classifier ground truth for image '" + id + "' and class '" + obj_class + "'";
  1597. CV_Error(Error::StsError,err_msg.c_str());
  1598. }
  1599. //-------------------------------------------------------------------
  1600. // Protected Functions (utility) ------------------------------------
  1601. //-------------------------------------------------------------------
  1602. //returns a vector containing indexes of the input vector in sorted ascending/descending order
  1603. void VocData::getSortOrder(const vector<float>& values, vector<size_t>& order, bool descending)
  1604. {
  1605. /* 1. store sorting order in 'order_pair' */
  1606. vector<std::pair<size_t, vector<float>::const_iterator> > order_pair(values.size());
  1607. size_t n = 0;
  1608. for (vector<float>::const_iterator it = values.begin(); it != values.end(); ++it, ++n)
  1609. order_pair[n] = make_pair(n, it);
  1610. std::sort(order_pair.begin(),order_pair.end(),orderingSorter());
  1611. if (descending == false) std::reverse(order_pair.begin(),order_pair.end());
  1612. vector<size_t>(order_pair.size()).swap(order);
  1613. for (size_t i = 0; i < order_pair.size(); ++i)
  1614. {
  1615. order[i] = order_pair[i].first;
  1616. }
  1617. }
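// For example, for values = {0.3f, 0.9f, 0.1f} sorted in descending order (the ordering used when ranking
// results by score), the resulting order is {1, 0, 2}, i.e. the indexes of the highest- to lowest-scoring entries.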
  1618. void VocData::readFileToString(const string filename, string& file_contents)
  1619. {
  1620. std::ifstream ifs(filename.c_str());
  1621. if (!ifs.is_open()) CV_Error(Error::StsError,"could not open text file");
  1622. stringstream oss;
  1623. oss << ifs.rdbuf();
  1624. file_contents = oss.str();
  1625. }
  1626. int VocData::stringToInteger(const string input_str)
  1627. {
  1628. int result = 0;
  1629. stringstream ss(input_str);
  1630. if ((ss >> result).fail())
  1631. {
  1632. CV_Error(Error::StsBadArg,"could not perform string to integer conversion");
  1633. }
  1634. return result;
  1635. }
  1636. string VocData::integerToString(const int input_int)
  1637. {
  1638. string result;
  1639. stringstream ss;
  1640. if ((ss << input_int).fail())
  1641. {
  1642. CV_Error(Error::StsBadArg,"could not perform integer to string conversion");
  1643. }
  1644. result = ss.str();
  1645. return result;
  1646. }
  1647. string VocData::checkFilenamePathsep( const string filename, bool add_trailing_slash )
  1648. {
  1649. string filename_new = filename;
  1650. size_t pos = filename_new.find("\\\\");
  1651. while (pos != filename_new.npos)
  1652. {
  1653. filename_new.replace(pos,2,"/");
  1654. pos = filename_new.find("\\\\", pos);
  1655. }
  1656. pos = filename_new.find("\\");
  1657. while (pos != filename_new.npos)
  1658. {
  1659. filename_new.replace(pos,1,"/");
  1660. pos = filename_new.find("\\", pos);
  1661. }
  1662. if (add_trailing_slash)
  1663. {
//add trailing slash if this is missing
  1665. if (filename_new.rfind("/") != filename_new.length()-1) filename_new += "/";
  1666. }
  1667. return filename_new;
  1668. }
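// For example, checkFilenamePathsep("C:\\data\\VOC2010", true) returns "C:/data/VOC2010/" (backslashes are
// replaced by forward slashes and a trailing slash is appended if missing).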
  1669. void VocData::convertImageCodesToObdImages(const vector<string>& image_codes, vector<ObdImage>& images)
  1670. {
  1671. images.clear();
  1672. images.reserve(image_codes.size());
  1673. string path;
  1674. //transfer to output arrays
  1675. for (size_t i = 0; i < image_codes.size(); ++i)
  1676. {
  1677. //generate image path and indices from extracted string code
  1678. path = getImagePath(image_codes[i]);
  1679. images.push_back(ObdImage(image_codes[i], path));
  1680. }
  1681. }
  1682. //Extract text from within a given tag from an XML file
  1683. //-----------------------------------------------------
  1684. //INPUTS:
  1685. // - src XML source file
  1686. // - tag XML tag delimiting block to extract
  1687. // - searchpos position within src at which to start search
  1688. //OUTPUTS:
  1689. // - tag_contents text extracted between <tag> and </tag> tags
  1690. //RETURN VALUE:
  1691. // - the position of the final character extracted in tag_contents within src
  1692. // (can be used to call extractXMLBlock recursively to extract multiple blocks)
  1693. // returns -1 if the tag could not be found
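//
// For example (illustrative): with src = "<object><name>cat</name></object>",
// extractXMLBlock(src, "name", 0, tag_contents) sets tag_contents to "cat" and returns the position within
// src at which the closing </name> tag begins.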
  1694. int VocData::extractXMLBlock(const string src, const string tag, const int searchpos, string& tag_contents)
  1695. {
  1696. size_t startpos, next_startpos, endpos;
  1697. int embed_count = 1;
  1698. //find position of opening tag
  1699. startpos = src.find("<" + tag + ">", searchpos);
  1700. if (startpos == string::npos) return -1;
  1701. //initialize endpos -
  1702. // start searching for end tag anywhere after opening tag
  1703. endpos = startpos;
  1704. //find position of next opening tag
  1705. next_startpos = src.find("<" + tag + ">", startpos+1);
  1706. //match opening tags with closing tags, and only
  1707. //accept final closing tag of same level as original
  1708. //opening tag
  1709. while (embed_count > 0)
  1710. {
  1711. endpos = src.find("</" + tag + ">", endpos+1);
  1712. if (endpos == string::npos) return -1;
  1713. //the next code is only executed if there are embedded tags with the same name
  1714. if (next_startpos != string::npos)
  1715. {
  1716. while (next_startpos<endpos)
  1717. {
  1718. //counting embedded start tags
  1719. ++embed_count;
  1720. next_startpos = src.find("<" + tag + ">", next_startpos+1);
  1721. if (next_startpos == string::npos) break;
  1722. }
  1723. }
  1724. //passing end tag so decrement nesting level
  1725. --embed_count;
  1726. }
  1727. //finally, extract the tag region
  1728. startpos += tag.length() + 2;
  1729. if (startpos > src.length()) return -1;
  1730. if (endpos > src.length()) return -1;
  1731. tag_contents = src.substr(startpos,endpos-startpos);
  1732. return static_cast<int>(endpos);
  1733. }
  1734. /****************************************************************************************\
  1735. * Sample on image classification *
  1736. \****************************************************************************************/
  1737. //
// This part of the code was slightly refactored
  1739. //
  1740. struct DDMParams
  1741. {
  1742. DDMParams() : detectorType("SURF"), descriptorType("SURF"), matcherType("BruteForce") {}
  1743. DDMParams( const string _detectorType, const string _descriptorType, const string& _matcherType ) :
  1744. detectorType(_detectorType), descriptorType(_descriptorType), matcherType(_matcherType){}
  1745. void read( const FileNode& fn )
  1746. {
  1747. fn["detectorType"] >> detectorType;
  1748. fn["descriptorType"] >> descriptorType;
  1749. fn["matcherType"] >> matcherType;
  1750. }
  1751. void write( FileStorage& fs ) const
  1752. {
  1753. fs << "detectorType" << detectorType;
  1754. fs << "descriptorType" << descriptorType;
  1755. fs << "matcherType" << matcherType;
  1756. }
  1757. void print() const
  1758. {
  1759. cout << "detectorType: " << detectorType << endl;
  1760. cout << "descriptorType: " << descriptorType << endl;
  1761. cout << "matcherType: " << matcherType << endl;
  1762. }
  1763. string detectorType;
  1764. string descriptorType;
  1765. string matcherType;
  1766. };
  1767. struct VocabTrainParams
  1768. {
  1769. VocabTrainParams() : trainObjClass("chair"), vocabSize(1000), memoryUse(200), descProportion(0.3f) {}
  1770. VocabTrainParams( const string _trainObjClass, size_t _vocabSize, size_t _memoryUse, float _descProportion ) :
  1771. trainObjClass(_trainObjClass), vocabSize((int)_vocabSize), memoryUse((int)_memoryUse), descProportion(_descProportion) {}
  1772. void read( const FileNode& fn )
  1773. {
  1774. fn["trainObjClass"] >> trainObjClass;
  1775. fn["vocabSize"] >> vocabSize;
  1776. fn["memoryUse"] >> memoryUse;
  1777. fn["descProportion"] >> descProportion;
  1778. }
  1779. void write( FileStorage& fs ) const
  1780. {
  1781. fs << "trainObjClass" << trainObjClass;
  1782. fs << "vocabSize" << vocabSize;
  1783. fs << "memoryUse" << memoryUse;
  1784. fs << "descProportion" << descProportion;
  1785. }
  1786. void print() const
  1787. {
  1788. cout << "trainObjClass: " << trainObjClass << endl;
  1789. cout << "vocabSize: " << vocabSize << endl;
  1790. cout << "memoryUse: " << memoryUse << endl;
  1791. cout << "descProportion: " << descProportion << endl;
  1792. }
  1793. string trainObjClass; // Object class used for training visual vocabulary.
  1794. // It shouldn't matter which object class is specified here - visual vocab will still be the same.
  1795. int vocabSize; //number of visual words in vocabulary to train
  1796. int memoryUse; // Memory to preallocate (in MB) when training vocab.
  1797. // Change this depending on the size of the dataset/available memory.
  1798. float descProportion; // Specifies the number of descriptors to use from each image as a proportion of the total num descs.
  1799. };
  1800. struct SVMTrainParamsExt
  1801. {
  1802. SVMTrainParamsExt() : descPercent(0.5f), targetRatio(0.4f), balanceClasses(true) {}
  1803. SVMTrainParamsExt( float _descPercent, float _targetRatio, bool _balanceClasses ) :
  1804. descPercent(_descPercent), targetRatio(_targetRatio), balanceClasses(_balanceClasses) {}
  1805. void read( const FileNode& fn )
  1806. {
  1807. fn["descPercent"] >> descPercent;
  1808. fn["targetRatio"] >> targetRatio;
  1809. fn["balanceClasses"] >> balanceClasses;
  1810. }
  1811. void write( FileStorage& fs ) const
  1812. {
  1813. fs << "descPercent" << descPercent;
  1814. fs << "targetRatio" << targetRatio;
  1815. fs << "balanceClasses" << balanceClasses;
  1816. }
  1817. void print() const
  1818. {
  1819. cout << "descPercent: " << descPercent << endl;
  1820. cout << "targetRatio: " << targetRatio << endl;
  1821. cout << "balanceClasses: " << balanceClasses << endl;
  1822. }
  1823. float descPercent; // Percentage of extracted descriptors to use for training.
  1824. float targetRatio; // Try to get this ratio of positive to negative samples (minimum).
bool balanceClasses; // Balance class weights by the number of samples in each class (if true, targetRatio is ignored).
  1826. };
  1827. static void readUsedParams( const FileNode& fn, string& vocName, DDMParams& ddmParams, VocabTrainParams& vocabTrainParams, SVMTrainParamsExt& svmTrainParamsExt )
  1828. {
  1829. fn["vocName"] >> vocName;
  1830. FileNode currFn = fn;
  1831. currFn = fn["ddmParams"];
  1832. ddmParams.read( currFn );
  1833. currFn = fn["vocabTrainParams"];
  1834. vocabTrainParams.read( currFn );
  1835. currFn = fn["svmTrainParamsExt"];
  1836. svmTrainParamsExt.read( currFn );
  1837. }
  1838. static void writeUsedParams( FileStorage& fs, const string& vocName, const DDMParams& ddmParams, const VocabTrainParams& vocabTrainParams, const SVMTrainParamsExt& svmTrainParamsExt )
  1839. {
  1840. fs << "vocName" << vocName;
  1841. fs << "ddmParams" << "{";
  1842. ddmParams.write(fs);
  1843. fs << "}";
  1844. fs << "vocabTrainParams" << "{";
  1845. vocabTrainParams.write(fs);
  1846. fs << "}";
  1847. fs << "svmTrainParamsExt" << "{";
  1848. svmTrainParamsExt.write(fs);
  1849. fs << "}";
  1850. }
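// When written with cv::FileStorage to a YAML file, the stored parameters take approximately the following
// form (illustrative values; exact formatting depends on the FileStorage backend):
//
//     vocName: VOC2010
//     ddmParams: { detectorType: SURF, descriptorType: SURF, matcherType: BruteForce }
//     vocabTrainParams: { trainObjClass: chair, vocabSize: 1000, memoryUse: 200, descProportion: 0.3 }
//     svmTrainParamsExt: { descPercent: 0.5, targetRatio: 0.4, balanceClasses: 1 }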
  1851. static void printUsedParams( const string& vocPath, const string& resDir,
  1852. const DDMParams& ddmParams, const VocabTrainParams& vocabTrainParams,
  1853. const SVMTrainParamsExt& svmTrainParamsExt )
  1854. {
  1855. cout << "CURRENT CONFIGURATION" << endl;
  1856. cout << "----------------------------------------------------------------" << endl;
  1857. cout << "vocPath: " << vocPath << endl;
  1858. cout << "resDir: " << resDir << endl;
  1859. cout << endl; ddmParams.print();
  1860. cout << endl; vocabTrainParams.print();
  1861. cout << endl; svmTrainParamsExt.print();
  1862. cout << "----------------------------------------------------------------" << endl << endl;
  1863. }
  1864. static bool readVocabulary( const string& filename, Mat& vocabulary )
  1865. {
  1866. cout << "Reading vocabulary...";
  1867. FileStorage fs( filename, FileStorage::READ );
  1868. if( fs.isOpened() )
  1869. {
  1870. fs["vocabulary"] >> vocabulary;
  1871. cout << "done" << endl;
  1872. return true;
  1873. }
  1874. return false;
  1875. }
  1876. static bool writeVocabulary( const string& filename, const Mat& vocabulary )
  1877. {
  1878. cout << "Saving vocabulary..." << endl;
  1879. FileStorage fs( filename, FileStorage::WRITE );
  1880. if( fs.isOpened() )
  1881. {
  1882. fs << "vocabulary" << vocabulary;
  1883. return true;
  1884. }
  1885. return false;
  1886. }
  1887. static Mat trainVocabulary( const string& filename, VocData& vocData, const VocabTrainParams& trainParams,
  1888. const Ptr<FeatureDetector>& fdetector, const Ptr<DescriptorExtractor>& dextractor )
  1889. {
  1890. Mat vocabulary;
  1891. if( !readVocabulary( filename, vocabulary) )
  1892. {
  1893. CV_Assert( dextractor->descriptorType() == CV_32FC1 );
  1894. const int elemSize = CV_ELEM_SIZE(dextractor->descriptorType());
  1895. const int descByteSize = dextractor->descriptorSize() * elemSize;
  1896. const int bytesInMB = 1048576;
  1897. const int maxDescCount = (trainParams.memoryUse * bytesInMB) / descByteSize; // Total number of descs to use for training.
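// e.g. with 64-dimensional CV_32F SURF descriptors (64*4 = 256 bytes each) and the default memoryUse of
// 200 MB, maxDescCount = 200*1048576/256 = 819200 descriptors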
  1898. cout << "Extracting VOC data..." << endl;
  1899. vector<ObdImage> images;
  1900. vector<char> objectPresent;
  1901. vocData.getClassImages( trainParams.trainObjClass, CV_OBD_TRAIN, images, objectPresent );
  1902. cout << "Computing descriptors..." << endl;
  1903. RNG& rng = theRNG();
  1904. TermCriteria terminate_criterion;
  1905. terminate_criterion.epsilon = FLT_EPSILON;
  1906. BOWKMeansTrainer bowTrainer( trainParams.vocabSize, terminate_criterion, 3, KMEANS_PP_CENTERS );
  1907. while( images.size() > 0 )
  1908. {
  1909. if( bowTrainer.descriptorsCount() > maxDescCount )
  1910. {
  1911. #ifdef DEBUG_DESC_PROGRESS
  1912. cout << "Breaking due to full memory ( descriptors count = " << bowTrainer.descriptorsCount()
  1913. << "; descriptor size in bytes = " << descByteSize << "; all used memory = "
<< bowTrainer.descriptorsCount()*descByteSize << " )" << endl;
  1915. #endif
  1916. break;
  1917. }
  1918. // Randomly pick an image from the dataset which hasn't yet been seen
  1919. // and compute the descriptors from that image.
  1920. int randImgIdx = rng( (unsigned)images.size() );
  1921. Mat colorImage = imread( images[randImgIdx].path );
  1922. vector<KeyPoint> imageKeypoints;
  1923. fdetector->detect( colorImage, imageKeypoints );
  1924. Mat imageDescriptors;
  1925. dextractor->compute( colorImage, imageKeypoints, imageDescriptors );
  1926. //check that there were descriptors calculated for the current image
  1927. if( !imageDescriptors.empty() )
  1928. {
  1929. int descCount = imageDescriptors.rows;
// Extract trainParams.descProportion of the descriptors from the image, stopping early once the maximum total descriptor count is reached
  1931. int descsToExtract = static_cast<int>(trainParams.descProportion * static_cast<float>(descCount));
  1932. // Fill mask of used descriptors
  1933. vector<char> usedMask( descCount, false );
  1934. fill( usedMask.begin(), usedMask.begin() + descsToExtract, true );
  1935. for( int i = 0; i < descCount; i++ )
  1936. {
  1937. int i1 = rng(descCount), i2 = rng(descCount);
  1938. char tmp = usedMask[i1]; usedMask[i1] = usedMask[i2]; usedMask[i2] = tmp;
  1939. }
  1940. for( int i = 0; i < descCount; i++ )
  1941. {
  1942. if( usedMask[i] && bowTrainer.descriptorsCount() < maxDescCount )
  1943. bowTrainer.add( imageDescriptors.row(i) );
  1944. }
  1945. }
  1946. #ifdef DEBUG_DESC_PROGRESS
  1947. cout << images.size() << " images left, " << images[randImgIdx].id << " processed - "
  1948. <</* descs_extracted << "/" << image_descriptors.rows << " extracted - " << */
  1949. cvRound((static_cast<double>(bowTrainer.descriptorsCount())/static_cast<double>(maxDescCount))*100.0)
  1950. << " % memory used" << ( imageDescriptors.empty() ? " -> no descriptors extracted, skipping" : "") << endl;
  1951. #endif
  1952. // Delete the current element from images so it is not added again
  1953. images.erase( images.begin() + randImgIdx );
  1954. }
  1955. cout << "Maximum allowed descriptor count: " << maxDescCount << ", Actual descriptor count: " << bowTrainer.descriptorsCount() << endl;
  1956. cout << "Training vocabulary..." << endl;
  1957. vocabulary = bowTrainer.cluster();
  1958. if( !writeVocabulary(filename, vocabulary) )
  1959. {
  1960. cout << "Error: file " << filename << " can not be opened to write" << endl;
  1961. exit(-1);
  1962. }
  1963. }
  1964. return vocabulary;
  1965. }
  1966. static bool readBowImageDescriptor( const string& file, Mat& bowImageDescriptor )
  1967. {
  1968. FileStorage fs( file, FileStorage::READ );
  1969. if( fs.isOpened() )
  1970. {
  1971. fs["imageDescriptor"] >> bowImageDescriptor;
  1972. return true;
  1973. }
  1974. return false;
  1975. }
  1976. static bool writeBowImageDescriptor( const string& file, const Mat& bowImageDescriptor )
  1977. {
  1978. FileStorage fs( file, FileStorage::WRITE );
  1979. if( fs.isOpened() )
  1980. {
  1981. fs << "imageDescriptor" << bowImageDescriptor;
  1982. return true;
  1983. }
  1984. return false;
  1985. }
  1986. // Load in the bag of words vectors for a set of images, from file if possible
  1987. static void calculateImageDescriptors( const vector<ObdImage>& images, vector<Mat>& imageDescriptors,
  1988. Ptr<BOWImgDescriptorExtractor>& bowExtractor, const Ptr<FeatureDetector>& fdetector,
  1989. const string& resPath )
  1990. {
  1991. CV_Assert( !bowExtractor->getVocabulary().empty() );
  1992. imageDescriptors.resize( images.size() );
  1993. for( size_t i = 0; i < images.size(); i++ )
  1994. {
  1995. string filename = resPath + bowImageDescriptorsDir + "/" + images[i].id + ".xml.gz";
  1996. if( readBowImageDescriptor( filename, imageDescriptors[i] ) )
  1997. {
  1998. #ifdef DEBUG_DESC_PROGRESS
  1999. cout << "Loaded bag of word vector for image " << i+1 << " of " << images.size() << " (" << images[i].id << ")" << endl;
  2000. #endif
  2001. }
  2002. else
  2003. {
  2004. Mat colorImage = imread( images[i].path );
  2005. #ifdef DEBUG_DESC_PROGRESS
  2006. cout << "Computing descriptors for image " << i+1 << " of " << images.size() << " (" << images[i].id << ")" << flush;
  2007. #endif
  2008. vector<KeyPoint> keypoints;
  2009. fdetector->detect( colorImage, keypoints );
  2010. #ifdef DEBUG_DESC_PROGRESS
  2011. cout << " + generating BoW vector" << std::flush;
  2012. #endif
  2013. bowExtractor->compute( colorImage, keypoints, imageDescriptors[i] );
  2014. #ifdef DEBUG_DESC_PROGRESS
  2015. cout << " ...DONE " << static_cast<int>(static_cast<float>(i+1)/static_cast<float>(images.size())*100.0)
  2016. << " % complete" << endl;
  2017. #endif
  2018. if( !imageDescriptors[i].empty() )
  2019. {
  2020. if( !writeBowImageDescriptor( filename, imageDescriptors[i] ) )
  2021. {
cout << "Error: file " << filename << " can not be opened to write bow image descriptor" << endl;
  2023. exit(-1);
  2024. }
  2025. }
  2026. }
  2027. }
  2028. }
  2029. static void removeEmptyBowImageDescriptors( vector<ObdImage>& images, vector<Mat>& bowImageDescriptors,
  2030. vector<char>& objectPresent )
  2031. {
  2032. CV_Assert( !images.empty() );
  2033. for( int i = (int)images.size() - 1; i >= 0; i-- )
  2034. {
  2035. bool res = bowImageDescriptors[i].empty();
  2036. if( res )
  2037. {
  2038. cout << "Removing image " << images[i].id << " due to no descriptors..." << endl;
  2039. images.erase( images.begin() + i );
  2040. bowImageDescriptors.erase( bowImageDescriptors.begin() + i );
  2041. objectPresent.erase( objectPresent.begin() + i );
  2042. }
  2043. }
  2044. }
static void removeBowImageDescriptorsByCount( vector<ObdImage>& images, vector<Mat>& bowImageDescriptors, vector<char>& objectPresent,
const SVMTrainParamsExt& svmParamsExt, int descsToDelete )
  2047. {
  2048. RNG& rng = theRNG();
  2049. int pos_ex = (int)std::count( objectPresent.begin(), objectPresent.end(), (char)1 );
  2050. int neg_ex = (int)std::count( objectPresent.begin(), objectPresent.end(), (char)0 );
  2051. while( descsToDelete != 0 )
  2052. {
  2053. int randIdx = rng((unsigned)images.size());
  2054. // Prefer positive training examples according to svmParamsExt.targetRatio if required
  2055. if( objectPresent[randIdx] )
  2056. {
  2057. if( (static_cast<float>(pos_ex)/static_cast<float>(neg_ex+pos_ex) < svmParamsExt.targetRatio) &&
  2058. (neg_ex > 0) && (svmParamsExt.balanceClasses == false) )
  2059. { continue; }
  2060. else
  2061. { pos_ex--; }
  2062. }
  2063. else
  2064. { neg_ex--; }
  2065. images.erase( images.begin() + randIdx );
  2066. bowImageDescriptors.erase( bowImageDescriptors.begin() + randIdx );
  2067. objectPresent.erase( objectPresent.begin() + randIdx );
  2068. descsToDelete--;
  2069. }
  2070. CV_Assert( bowImageDescriptors.size() == objectPresent.size() );
  2071. }
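
// Configure the SVM as a C-SVC with an RBF kernel and, optionally, per-class weights.
// A minimal sketch of the equivalent manual setup (the weight values here are illustrative only):
//
//   Ptr<SVM> svm = SVM::create();
//   svm->setType(SVM::C_SVC);
//   svm->setKernel(SVM::RBF);
//   Mat wts = (Mat_<float>(2, 1) << 0.7f, 0.3f);   // hypothetical class weights
//   svm->setClassWeights(wts);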
static void setSVMParams( Ptr<SVM>& svm, const Mat& responses, bool balanceClasses )
{
    int pos_ex = countNonZero(responses == 1);
    int neg_ex = countNonZero(responses == -1);
    cout << pos_ex << " positive training samples; " << neg_ex << " negative training samples" << endl;

    svm->setType(SVM::C_SVC);
    svm->setKernel(SVM::RBF);
    if( balanceClasses )
    {
        Mat class_wts( 2, 1, CV_32FC1 );
        // The first training sample determines the '+1' class internally, even if it is negative,
        // so store whether this is the case so that the class weights can be reversed accordingly.
        // Note: responses is CV_32SC1 (filled with +/-1), so read it as int.
        bool reversed_classes = (responses.at<int>(0) < 0);
        if( reversed_classes == false )
        {
            // weighting for the cost of the positive class +1 (i.e. cost of a false positive; larger gives greater cost)
            class_wts.at<float>(0) = static_cast<float>(pos_ex)/static_cast<float>(pos_ex+neg_ex);
            // weighting for the cost of the negative class -1 (i.e. cost of a false negative)
            class_wts.at<float>(1) = static_cast<float>(neg_ex)/static_cast<float>(pos_ex+neg_ex);
        }
        else
        {
            class_wts.at<float>(0) = static_cast<float>(neg_ex)/static_cast<float>(pos_ex+neg_ex);
            class_wts.at<float>(1) = static_cast<float>(pos_ex)/static_cast<float>(pos_ex+neg_ex);
        }
        svm->setClassWeights(class_wts);
    }
}
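
// Build the parameter grids for SVM::trainAuto. Setting logStep to 0 for the P, NU, COEF and
// DEGREE grids disables the search over those parameters, so cross-validation only explores
// C and GAMMA, the two parameters relevant for a C-SVC with an RBF kernel.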
static void setSVMTrainAutoParams( ParamGrid& c_grid, ParamGrid& gamma_grid,
                                   ParamGrid& p_grid, ParamGrid& nu_grid,
                                   ParamGrid& coef_grid, ParamGrid& degree_grid )
{
    c_grid = SVM::getDefaultGrid(SVM::C);
    gamma_grid = SVM::getDefaultGrid(SVM::GAMMA);
    p_grid = SVM::getDefaultGrid(SVM::P);
    p_grid.logStep = 0;
    nu_grid = SVM::getDefaultGrid(SVM::NU);
    nu_grid.logStep = 0;
    coef_grid = SVM::getDefaultGrid(SVM::COEF);
    coef_grid.logStep = 0;
    degree_grid = SVM::getDefaultGrid(SVM::DEGREE);
    degree_grid.logStep = 0;
}
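
// Train (or load from cache) a one-vs-rest SVM for a single VOC object class:
//   1. load a cached classifier from <resPath>/<svmsDir>/<class>.xml.gz if it exists;
//   2. otherwise compute BoW vectors for the class's training images, build the sample and
//      response matrices, and run cross-validated training via SVM::trainAuto;
//   3. save the trained classifier for subsequent runs.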
static Ptr<SVM> trainSVMClassifier( const SVMTrainParamsExt& svmParamsExt, const string& objClassName, VocData& vocData,
                                    Ptr<BOWImgDescriptorExtractor>& bowExtractor, const Ptr<FeatureDetector>& fdetector,
                                    const string& resPath )
{
    /* first check if a previously trained svm for the current class has been saved to file */
    string svmFilename = resPath + svmsDir + "/" + objClassName + ".xml.gz";
    Ptr<SVM> svm;

    FileStorage fs( svmFilename, FileStorage::READ );
    if( fs.isOpened() )
    {
        cout << "*** LOADING SVM CLASSIFIER FOR CLASS " << objClassName << " ***" << endl;
        svm = StatModel::load<SVM>( svmFilename );
    }
    else
    {
        cout << "*** TRAINING CLASSIFIER FOR CLASS " << objClassName << " ***" << endl;
        cout << "CALCULATING BOW VECTORS FOR TRAINING SET OF " << objClassName << "..." << endl;

        // Get classification ground truth for images in the training set
        vector<ObdImage> images;
        vector<Mat> bowImageDescriptors;
        vector<char> objectPresent;
        vocData.getClassImages( objClassName, CV_OBD_TRAIN, images, objectPresent );

        // Compute the bag of words vector for each image in the training set
        calculateImageDescriptors( images, bowImageDescriptors, bowExtractor, fdetector, resPath );

        // Remove any images for which descriptors could not be calculated
        removeEmptyBowImageDescriptors( images, bowImageDescriptors, objectPresent );

        CV_Assert( svmParamsExt.descPercent > 0.f && svmParamsExt.descPercent <= 1.f );
        if( svmParamsExt.descPercent < 1.f )
        {
            int descsToDelete = static_cast<int>(static_cast<float>(images.size())*(1.0-svmParamsExt.descPercent));
            cout << "Using " << (images.size() - descsToDelete) << " of " << images.size() <<
                    " descriptors for training (" << svmParamsExt.descPercent*100.0 << " %)" << endl;
            removeBowImageDescriptorsByCount( images, bowImageDescriptors, objectPresent, svmParamsExt, descsToDelete );
        }

        // Prepare the input matrices for SVM training
        Mat trainData( (int)images.size(), bowExtractor->getVocabulary().rows, CV_32FC1 );
        Mat responses( (int)images.size(), 1, CV_32SC1 );

        // Transfer bag of words vectors and responses across to the training data matrices
        for( size_t imageIdx = 0; imageIdx < images.size(); imageIdx++ )
        {
            // Transfer image descriptor (bag of words vector) to training data matrix
            Mat submat = trainData.row((int)imageIdx);
            if( bowImageDescriptors[imageIdx].cols != bowExtractor->descriptorSize() )
            {
                cout << "Error: computed BoW image descriptor size " << bowImageDescriptors[imageIdx].cols
                     << " differs from vocabulary size " << bowExtractor->getVocabulary().rows << endl;
                exit(-1);
            }
            bowImageDescriptors[imageIdx].copyTo( submat );

            // Set response value
            responses.at<int>((int)imageIdx) = objectPresent[imageIdx] ? 1 : -1;
        }

        cout << "TRAINING SVM FOR CLASS " << objClassName << "..." << endl;
        svm = SVM::create();
        setSVMParams( svm, responses, svmParamsExt.balanceClasses );
        ParamGrid c_grid, gamma_grid, p_grid, nu_grid, coef_grid, degree_grid;
        setSVMTrainAutoParams( c_grid, gamma_grid, p_grid, nu_grid, coef_grid, degree_grid );
        svm->trainAuto(TrainData::create(trainData, ROW_SAMPLE, responses), 10,
                       c_grid, gamma_grid, p_grid, nu_grid, coef_grid, degree_grid);
        cout << "SVM TRAINING FOR CLASS " << objClassName << " COMPLETED" << endl;

        svm->save( svmFilename );
        cout << "SAVED CLASSIFIER TO FILE" << endl;
    }
    return svm;
}
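
// Run the trained SVM over the test set and record a signed confidence score per image.
// The raw decision-function value returned with StatModel::RAW_OUTPUT has an arbitrary sign
// convention, so the first sample is predicted twice (class label and raw score) to determine
// whether the raw score must be negated so that larger values mean "object present".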
static void computeConfidences( const Ptr<SVM>& svm, const string& objClassName, VocData& vocData,
                                Ptr<BOWImgDescriptorExtractor>& bowExtractor, const Ptr<FeatureDetector>& fdetector,
                                const string& resPath )
{
    cout << "*** CALCULATING CONFIDENCES FOR CLASS " << objClassName << " ***" << endl;
    cout << "CALCULATING BOW VECTORS FOR TEST SET OF " << objClassName << "..." << endl;

    // Get classification ground truth for images in the test set
    vector<ObdImage> images;
    vector<Mat> bowImageDescriptors;
    vector<char> objectPresent;
    vocData.getClassImages( objClassName, CV_OBD_TEST, images, objectPresent );

    // Compute the bag of words vector for each image in the test set
    calculateImageDescriptors( images, bowImageDescriptors, bowExtractor, fdetector, resPath );

    // Remove any images for which descriptors could not be calculated
    removeEmptyBowImageDescriptors( images, bowImageDescriptors, objectPresent );

    // Use the bag of words vectors to calculate the classifier output for each image in the test set
    cout << "CALCULATING CONFIDENCE SCORES FOR CLASS " << objClassName << "..." << endl;
    vector<float> confidences( images.size() );
    float signMul = 1.f;
    for( size_t imageIdx = 0; imageIdx < images.size(); imageIdx++ )
    {
        if( imageIdx == 0 )
        {
            // In the first iteration, determine the sign of the positive class
            float classVal = confidences[imageIdx] = svm->predict( bowImageDescriptors[imageIdx], noArray(), 0 );
            float scoreVal = confidences[imageIdx] = svm->predict( bowImageDescriptors[imageIdx], noArray(), StatModel::RAW_OUTPUT );
            signMul = (classVal < 0) == (scoreVal < 0) ? 1.f : -1.f;
        }
        // SVM output of the decision function
        confidences[imageIdx] = signMul * svm->predict( bowImageDescriptors[imageIdx], noArray(), StatModel::RAW_OUTPUT );
    }

    cout << "WRITING QUERY RESULTS TO VOC RESULTS FILE FOR CLASS " << objClassName << "..." << endl;
    vocData.writeClassifierResultsFile( resPath + plotsDir, objClassName, CV_OBD_TEST, images, confidences, 1, true );

    cout << "DONE - " << objClassName << endl;
    cout << "---------------------------------------------------------------" << endl;
}
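
// Compute the precision-recall curve and average precision for a class from the results file
// written by computeConfidences, and emit a GNUPlot script that renders it to a PNG.
// An illustrative (hypothetical filename) command to render the plot afterwards:
//
//   gnuplot results/plots/comp1_cls_test_aeroplane.plt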
static void computeGnuPlotOutput( const string& resPath, const string& objClassName, VocData& vocData )
{
    vector<float> precision, recall;
    float ap;

    const string resultFile = vocData.getResultsFilename( objClassName, CV_VOC_TASK_CLASSIFICATION, CV_OBD_TEST );
    const string plotFile = resultFile.substr(0, resultFile.size()-4) + ".plt";

    cout << "Calculating precision-recall curve for class '" << objClassName << "'" << endl;
    vocData.calcClassifierPrecRecall( resPath + plotsDir + "/" + resultFile, precision, recall, ap, true );
    cout << "Outputting to GNUPlot file..." << endl;
    vocData.savePrecRecallToGnuplot( resPath + plotsDir + "/" + plotFile, precision, recall, ap, objClassName, CV_VOC_PLOT_PNG );
}
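
// Map a detector/descriptor name given on the command line to the corresponding OpenCV
// Feature2D factory; an unrecognised name yields a null pointer, which main() treats as an error.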
static Ptr<Feature2D> createByName( const String& name )
{
    if( name == "SIFT" )
        return SIFT::create();
    if( name == "SURF" )
        return SURF::create();
    if( name == "ORB" )
        return ORB::create();
    if( name == "BRISK" )
        return BRISK::create();
    if( name == "KAZE" )
        return KAZE::create();
    if( name == "AKAZE" )
        return AKAZE::create();
    return Ptr<Feature2D>();
}
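
// Overall pipeline: parse arguments, create the detector/descriptor/matcher, train (or load) the
// visual-word vocabulary, then for every VOC object class train an SVM, score the test set and
// plot precision-recall. Illustrative invocations (paths and names here are examples only):
//
//   ./bagofwords_classification /data/VOC2010 /data/results
//   ./bagofwords_classification /data/VOC2010 /data/results SURF SURF BruteForce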
int main(int argc, char** argv)
{
    if( argc != 3 && argc != 6 )
    {
        help(argv);
        return -1;
    }

    const string vocPath = argv[1], resPath = argv[2];

    // Read or set default parameters
    string vocName;
    DDMParams ddmParams;
    VocabTrainParams vocabTrainParams;
    SVMTrainParamsExt svmTrainParamsExt;

    makeUsedDirs( resPath );

    FileStorage paramsFS( resPath + "/" + paramsFile, FileStorage::READ );
    if( paramsFS.isOpened() )
    {
        readUsedParams( paramsFS.root(), vocName, ddmParams, vocabTrainParams, svmTrainParamsExt );
        CV_Assert( vocName == getVocName(vocPath) );
    }
    else
    {
        vocName = getVocName(vocPath);
        if( argc != 6 )
        {
            cout << "Feature detector, descriptor extractor and descriptor matcher must be set" << endl;
            return -1;
        }
        ddmParams = DDMParams( argv[3], argv[4], argv[5] ); // from command line
        // vocabTrainParams and svmTrainParamsExt are left at their defaults

        paramsFS.open( resPath + "/" + paramsFile, FileStorage::WRITE );
        if( paramsFS.isOpened() )
        {
            writeUsedParams( paramsFS, vocName, ddmParams, vocabTrainParams, svmTrainParamsExt );
            paramsFS.release();
        }
        else
        {
            cout << "File " << (resPath + "/" + paramsFile) << " cannot be opened to write" << endl;
            return -1;
        }
    }

    // Create detector, descriptor, matcher.
    if( ddmParams.detectorType != ddmParams.descriptorType )
    {
        cout << "Error: the detector and descriptor types must be the same\n";
        return -1;
    }
    Ptr<Feature2D> featureDetector = createByName( ddmParams.detectorType );
    Ptr<DescriptorExtractor> descExtractor = featureDetector;
    Ptr<BOWImgDescriptorExtractor> bowExtractor;
    if( !featureDetector || !descExtractor )
    {
        cout << "featureDetector or descExtractor was not created" << endl;
        return -1;
    }
    {
        Ptr<DescriptorMatcher> descMatcher = DescriptorMatcher::create( ddmParams.matcherType );
        if( !descMatcher )
        {
            cout << "descMatcher was not created" << endl;
            return -1;
        }
        bowExtractor = makePtr<BOWImgDescriptorExtractor>( descExtractor, descMatcher );
    }

    // Print configuration to screen
    printUsedParams( vocPath, resPath, ddmParams, vocabTrainParams, svmTrainParamsExt );

    // Create object to work with VOC
    VocData vocData( vocPath, false );

    // 1. Train the visual word vocabulary if a pre-calculated vocabulary file doesn't already exist from a previous run
    Mat vocabulary = trainVocabulary( resPath + "/" + vocabularyFile, vocData, vocabTrainParams,
                                      featureDetector, descExtractor );
    bowExtractor->setVocabulary( vocabulary );

    // 2. Train a classifier and run a sample query for each object class
    const vector<string>& objClasses = vocData.getObjectClasses(); // object class list
    for( size_t classIdx = 0; classIdx < objClasses.size(); ++classIdx )
    {
        // Train a classifier on the train dataset
        Ptr<SVM> svm = trainSVMClassifier( svmTrainParamsExt, objClasses[classIdx], vocData,
                                           bowExtractor, featureDetector, resPath );

        // Now use the classifier over all images of the test dataset, ranking them by score
        // and also calculating precision-recall etc.
        computeConfidences( svm, objClasses[classIdx], vocData,
                            bowExtractor, featureDetector, resPath );
        // Calculate precision/recall/AP and use GNUPlot to output to a PNG file
        computeGnuPlotOutput( resPath, objClasses[classIdx], vocData );
    }
    return 0;
}
#else

int main()
{
    std::cerr << "OpenCV was built without the ml module" << std::endl;
    return 0;
}

#endif // HAVE_OPENCV_ML