/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                        Intel License Agreement
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of Intel Corporation may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "old_ml_precomp.hpp"
#include <ctype.h>

using namespace cv;

static const float ord_nan = FLT_MAX*0.5f;
static const int min_block_size = 1 << 16;
static const int block_size_delta = 1 << 10;
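// ord_nan is the sentinel stored for missing ordered-variable values: any
// value with |val| >= FLT_MAX/2 is rejected as "too large"/missing below.
// The two block-size constants control how the CvMemStorage pools are grown:
// at least 64 KB (1 << 16) per block, padded in 1 KB (1 << 10) steps.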
CvDTreeTrainData::CvDTreeTrainData()
{
    var_idx = var_type = cat_count = cat_ofs = cat_map =
        priors = priors_mult = counts = direction = split_buf = responses_copy = 0;
    buf = 0;
    tree_storage = temp_storage = 0;

    clear();
}


CvDTreeTrainData::CvDTreeTrainData( const CvMat* _train_data, int _tflag,
                      const CvMat* _responses, const CvMat* _var_idx,
                      const CvMat* _sample_idx, const CvMat* _var_type,
                      const CvMat* _missing_mask, const CvDTreeParams& _params,
                      bool _shared, bool _add_labels )
{
    var_idx = var_type = cat_count = cat_ofs = cat_map =
        priors = priors_mult = counts = direction = split_buf = responses_copy = 0;
    buf = 0;
    tree_storage = temp_storage = 0;

    set_data( _train_data, _tflag, _responses, _var_idx, _sample_idx,
              _var_type, _missing_mask, _params, _shared, _add_labels );
}


CvDTreeTrainData::~CvDTreeTrainData()
{
    clear();
}


bool CvDTreeTrainData::set_params( const CvDTreeParams& _params )
{
    bool ok = false;

    CV_FUNCNAME( "CvDTreeTrainData::set_params" );

    __BEGIN__;

    // set parameters
    params = _params;

    if( params.max_categories < 2 )
        CV_ERROR( CV_StsOutOfRange, "params.max_categories should be >= 2" );
    params.max_categories = MIN( params.max_categories, 15 );

    if( params.max_depth < 0 )
        CV_ERROR( CV_StsOutOfRange, "params.max_depth should be >= 0" );
    params.max_depth = MIN( params.max_depth, 25 );

    params.min_sample_count = MAX(params.min_sample_count, 1);

    if( params.cv_folds < 0 )
        CV_ERROR( CV_StsOutOfRange,
            "params.cv_folds should be =0 (the tree is not pruned) "
            "or n>0 (tree is pruned using n-fold cross-validation)" );

    if( params.cv_folds == 1 )
        params.cv_folds = 0;

    if( params.regression_accuracy < 0 )
        CV_ERROR( CV_StsOutOfRange, "params.regression_accuracy should be >= 0" );

    ok = true;

    __END__;

    return ok;
}
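// A minimal sketch of how set_params() clamps its inputs (hypothetical values,
// not taken from this file): out-of-range positive values are silently clamped,
// cv_folds == 1 is treated as "no pruning", and negative values raise errors.
//
//   CvDTreeParams p;
//   p.max_categories = 100;   // becomes 15
//   p.max_depth = 40;         // becomes 25
//   p.cv_folds = 1;           // becomes 0 (no cross-validation pruning)
//   train_data->set_params( p );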
template<typename T>
class LessThanPtr
{
public:
    bool operator()(T* a, T* b) const { return *a < *b; }
};

template<typename T, typename Idx>
class LessThanIdx
{
public:
    LessThanIdx( const T* _arr ) : arr(_arr) {}
    bool operator()(Idx a, Idx b) const { return arr[a] < arr[b]; }
    const T* arr;
};

class LessThanPairs
{
public:
    bool operator()(const CvPair16u32s& a, const CvPair16u32s& b) const { return *a.i < *b.i; }
};
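// These comparators let std::sort order *indices* (or pointers) by the values
// they refer to, so the underlying data is never moved. For example, sorting
// an index array idx[] over values v[] (hypothetical names):
//
//   float v[] = { 3.f, 1.f, 2.f };
//   int idx[] = { 0, 1, 2 };
//   std::sort( idx, idx + 3, LessThanIdx<float,int>(v) );  // idx -> {1, 2, 0}
//
// set_data() below uses LessThanPtr/LessThanPairs the same way to sort the
// per-variable index columns of the internal buffer.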
void CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
    const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,
    const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,
    bool _shared, bool _add_labels, bool _update_data )
{
    CvMat* sample_indices = 0;
    CvMat* var_type0 = 0;
    CvMat* tmp_map = 0;
    int** int_ptr = 0;
    CvPair16u32s* pair16u32s_ptr = 0;
    CvDTreeTrainData* data = 0;
    float *_fdst = 0;
    int *_idst = 0;
    unsigned short* udst = 0;
    int* idst = 0;

    CV_FUNCNAME( "CvDTreeTrainData::set_data" );

    __BEGIN__;

    int sample_all = 0, r_type, cv_n;
    int total_c_count = 0;
    int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;
    int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step
    int vi, i, size;
    char err[100];
    const int *sidx = 0, *vidx = 0;
    uint64 effective_buf_size = 0;
    int effective_buf_height = 0, effective_buf_width = 0;

    if( _update_data && data_root )
    {
        data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,
            _sample_idx, _var_type, _missing_mask, _params, _shared, _add_labels );

        // compare new and old train data
        if( !(data->var_count == var_count &&
            cvNorm( data->var_type, var_type, CV_C ) < FLT_EPSILON &&
            cvNorm( data->cat_count, cat_count, CV_C ) < FLT_EPSILON &&
            cvNorm( data->cat_map, cat_map, CV_C ) < FLT_EPSILON) )
            CV_ERROR( CV_StsBadArg,
                "The new training data must have the same types of the input and output variables "
                "and the same categories for categorical variables" );

        cvReleaseMat( &priors );
        cvReleaseMat( &priors_mult );
        cvReleaseMat( &buf );
        cvReleaseMat( &direction );
        cvReleaseMat( &split_buf );
        cvReleaseMemStorage( &temp_storage );

        priors = data->priors; data->priors = 0;
        priors_mult = data->priors_mult; data->priors_mult = 0;
        buf = data->buf; data->buf = 0;
        buf_count = data->buf_count; buf_size = data->buf_size;
        sample_count = data->sample_count;

        direction = data->direction; data->direction = 0;
        split_buf = data->split_buf; data->split_buf = 0;
        temp_storage = data->temp_storage; data->temp_storage = 0;
        nv_heap = data->nv_heap; cv_heap = data->cv_heap;

        data_root = new_node( 0, sample_count, 0, 0 );
        EXIT;
    }

    clear();

    var_all = 0;
    rng = &cv::theRNG();

    CV_CALL( set_params( _params ));

    // check parameter types and sizes
    CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));

    train_data = _train_data;
    responses = _responses;

    if( _tflag == CV_ROW_SAMPLE )
    {
        ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
        dv_step = 1;
        if( _missing_mask )
            ms_step = _missing_mask->step, mv_step = 1;
    }
    else
    {
        dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
        ds_step = 1;
        if( _missing_mask )
            mv_step = _missing_mask->step, ms_step = 1;
    }
    tflag = _tflag;

    sample_count = sample_all;
    var_count = var_all;

    if( _sample_idx )
    {
        CV_CALL( sample_indices = cvPreprocessIndexArray( _sample_idx, sample_all ));
        sidx = sample_indices->data.i;
        sample_count = sample_indices->rows + sample_indices->cols - 1;
    }

    if( _var_idx )
    {
        CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));
        vidx = var_idx->data.i;
        var_count = var_idx->rows + var_idx->cols - 1;
    }

    is_buf_16u = false;
    if ( sample_count < 65536 )
        is_buf_16u = true;
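    // With fewer than 65536 samples every per-sample index fits into 16 bits,
    // so the work buffer is allocated as CV_16UC1 instead of CV_32SC1, halving
    // its memory footprint. 65535 is then reserved as the "missing" marker,
    // playing the role that -1 plays in the 32-bit branch.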
    if( !CV_IS_MAT(_responses) ||
        (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&
        CV_MAT_TYPE(_responses->type) != CV_32FC1) ||
        (_responses->rows != 1 && _responses->cols != 1) ||
        _responses->rows + _responses->cols - 1 != sample_all )
        CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "
                  "floating-point vector containing as many elements as "
                  "the total number of samples in the training data matrix" );

    r_type = CV_VAR_CATEGORICAL;
    if( _var_type )
        CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_count, &r_type ));

    CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));

    cat_var_count = 0;
    ord_var_count = -1;

    is_classifier = r_type == CV_VAR_CATEGORICAL;

    // step 0. calc the number of categorical vars
    for( vi = 0; vi < var_count; vi++ )
    {
        char vt = var_type0 ? var_type0->data.ptr[vi] : CV_VAR_ORDERED;
        var_type->data.i[vi] = vt == CV_VAR_CATEGORICAL ? cat_var_count++ : ord_var_count--;
    }

    ord_var_count = ~ord_var_count;
    cv_n = params.cv_folds;
    // set the two last elements of var_type array to be able
    // to locate responses and cross-validation labels using
    // the corresponding get_* functions.
    var_type->data.i[var_count] = cat_var_count;
    var_type->data.i[var_count+1] = cat_var_count+1;

    // in case of single ordered predictor we need dummy cv_labels
    // for safe split_node_data() operation
    have_labels = cv_n > 0 || (ord_var_count == 1 && cat_var_count == 0) || _add_labels;

    work_var_count = var_count + (is_classifier ? 1 : 0) // for responses class_labels
                               + (have_labels ? 1 : 0); // for cv_labels

    shared = _shared;
    buf_count = shared ? 2 : 1;

    buf_size = -1; // the member buf_size is obsolete
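    // Layout of "buf": one row of sample_count entries per working variable,
    // i.e. (work_var_count + 1) rows in all -- var_count predictor rows,
    // optionally a class-labels row and a cv-labels row, plus a final row of
    // sample indices. Shared data keeps buf_count = 2 copies so that parent
    // and derived nodes can live in different sub-buffers.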
    effective_buf_size = (uint64)(work_var_count + 1)*(uint64)sample_count * buf_count; // this is the total size of "CvMat buf" to be allocated
    effective_buf_width = sample_count;
    effective_buf_height = work_var_count+1;

    if (effective_buf_width >= effective_buf_height)
        effective_buf_height *= buf_count;
    else
        effective_buf_width *= buf_count;

    if ((uint64)effective_buf_width * (uint64)effective_buf_height != effective_buf_size)
    {
        CV_Error(CV_StsBadArg, "The memory buffer cannot be allocated since its size exceeds integer fields limit");
    }

    if ( is_buf_16u )
    {
        CV_CALL( buf = cvCreateMat( effective_buf_height, effective_buf_width, CV_16UC1 ));
        CV_CALL( pair16u32s_ptr = (CvPair16u32s*)cvAlloc( sample_count*sizeof(pair16u32s_ptr[0]) ));
    }
    else
    {
        CV_CALL( buf = cvCreateMat( effective_buf_height, effective_buf_width, CV_32SC1 ));
        CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));
    }

    size = is_classifier ? (cat_var_count+1) : cat_var_count;
    size = !size ? 1 : size;
    CV_CALL( cat_count = cvCreateMat( 1, size, CV_32SC1 ));
    CV_CALL( cat_ofs = cvCreateMat( 1, size, CV_32SC1 ));

    size = is_classifier ? (cat_var_count + 1)*params.max_categories : cat_var_count*params.max_categories;
    size = !size ? 1 : size;
    CV_CALL( cat_map = cvCreateMat( 1, size, CV_32SC1 ));

    // now calculate the maximum size of split,
    // create memory storage that will keep nodes and splits of the decision tree
    // allocate root node and the buffer for the whole training data
    max_split_size = cvAlign(sizeof(CvDTreeSplit) +
        (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));
    tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
    tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
    CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
    CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));

    nv_size = var_count*sizeof(int);
    nv_size = cvAlign(MAX( nv_size, (int)sizeof(CvSetElem) ), sizeof(void*));

    temp_block_size = nv_size;

    if( cv_n )
    {
        if( sample_count < cv_n*MAX(params.min_sample_count,10) )
            CV_ERROR( CV_StsOutOfRange,
                "Too many folds in cross-validation for such a small dataset" );

        cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );
        temp_block_size = MAX(temp_block_size, cv_size);
    }

    temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );
    CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));
    CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));
    if( cv_size )
        CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));

    CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));

    max_c_count = 1;

    _fdst = 0;
    _idst = 0;
    if (ord_var_count)
        _fdst = (float*)cvAlloc(sample_count*sizeof(_fdst[0]));
    if (is_buf_16u && (cat_var_count || is_classifier))
        _idst = (int*)cvAlloc(sample_count*sizeof(_idst[0]));

    // transform the training data to convenient representation
    for( vi = 0; vi <= var_count; vi++ )
    {
        int ci;
        const uchar* mask = 0;
        int64 m_step = 0, step;
        const int* idata = 0;
        const float* fdata = 0;
        int num_valid = 0;

        if( vi < var_count ) // analyze i-th input variable
        {
            int vi0 = vidx ? vidx[vi] : vi;
            ci = get_var_type(vi);
            step = ds_step; m_step = ms_step;
            if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )
                idata = _train_data->data.i + vi0*dv_step;
            else
                fdata = _train_data->data.fl + vi0*dv_step;
            if( _missing_mask )
                mask = _missing_mask->data.ptr + vi0*mv_step;
        }
        else // analyze _responses
        {
            ci = cat_var_count;
            step = CV_IS_MAT_CONT(_responses->type) ?
                1 : _responses->step / CV_ELEM_SIZE(_responses->type);
            if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )
                idata = _responses->data.i;
            else
                fdata = _responses->data.fl;
        }

        if( (vi < var_count && ci >= 0) ||
            (vi == var_count && is_classifier) ) // process categorical variable or response
        {
            int c_count, prev_label;
            int* c_map;

            if (is_buf_16u)
                udst = (unsigned short*)(buf->data.s + (size_t)vi*sample_count);
            else
                idst = buf->data.i + (size_t)vi*sample_count;

            // copy data
            for( i = 0; i < sample_count; i++ )
            {
                int val = INT_MAX, si = sidx ? sidx[i] : i;
                if( !mask || !mask[(size_t)si*m_step] )
                {
                    if( idata )
                        val = idata[(size_t)si*step];
                    else
                    {
                        float t = fdata[(size_t)si*step];
                        val = cvRound(t);
                        if( fabs(t - val) > FLT_EPSILON )
                        {
                            sprintf( err, "%d-th value of %d-th (categorical) "
                                "variable is not an integer", i, vi );
                            CV_ERROR( CV_StsBadArg, err );
                        }
                    }

                    if( val == INT_MAX )
                    {
                        sprintf( err, "%d-th value of %d-th (categorical) "
                            "variable is too large", i, vi );
                        CV_ERROR( CV_StsBadArg, err );
                    }
                    num_valid++;
                }

                if (is_buf_16u)
                {
                    _idst[i] = val;
                    pair16u32s_ptr[i].u = udst + i;
                    pair16u32s_ptr[i].i = _idst + i;
                }
                else
                {
                    idst[i] = val;
                    int_ptr[i] = idst + i;
                }
            }

            c_count = num_valid > 0;
            if (is_buf_16u)
            {
                std::sort(pair16u32s_ptr, pair16u32s_ptr + sample_count, LessThanPairs());
                // count the categories
                for( i = 1; i < num_valid; i++ )
                    if (*pair16u32s_ptr[i].i != *pair16u32s_ptr[i-1].i)
                        c_count++;
            }
            else
            {
                std::sort(int_ptr, int_ptr + sample_count, LessThanPtr<int>());
                // count the categories
                for( i = 1; i < num_valid; i++ )
                    c_count += *int_ptr[i] != *int_ptr[i-1];
            }

            if( vi > 0 )
                max_c_count = MAX( max_c_count, c_count );
            cat_count->data.i[ci] = c_count;
            cat_ofs->data.i[ci] = total_c_count;

            // resize cat_map, if needed
            if( cat_map->cols < total_c_count + c_count )
            {
                tmp_map = cat_map;
                CV_CALL( cat_map = cvCreateMat( 1,
                    MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));
                for( i = 0; i < total_c_count; i++ )
                    cat_map->data.i[i] = tmp_map->data.i[i];
                cvReleaseMat( &tmp_map );
            }

            c_map = cat_map->data.i + total_c_count;
            total_c_count += c_count;
            c_count = -1;

            if (is_buf_16u)
            {
                // compact the class indices and build the map
                prev_label = ~*pair16u32s_ptr[0].i;
                for( i = 0; i < num_valid; i++ )
                {
                    int cur_label = *pair16u32s_ptr[i].i;
                    if( cur_label != prev_label )
                        c_map[++c_count] = prev_label = cur_label;
                    *pair16u32s_ptr[i].u = (unsigned short)c_count;
                }
                // replace labels for missing values with 65535
                // (the 16-bit counterpart of -1)
                for( ; i < sample_count; i++ )
                    *pair16u32s_ptr[i].u = 65535;
            }
            else
            {
                // compact the class indices and build the map
                prev_label = ~*int_ptr[0];
                for( i = 0; i < num_valid; i++ )
                {
                    int cur_label = *int_ptr[i];
                    if( cur_label != prev_label )
                        c_map[++c_count] = prev_label = cur_label;
                    *int_ptr[i] = c_count;
                }
                // replace labels for missing values with -1
                for( ; i < sample_count; i++ )
                    *int_ptr[i] = -1;
            }
        }
        else if( ci < 0 ) // process ordered variable
        {
            if (is_buf_16u)
                udst = (unsigned short*)(buf->data.s + (size_t)vi*sample_count);
            else
                idst = buf->data.i + (size_t)vi*sample_count;

            for( i = 0; i < sample_count; i++ )
            {
                float val = ord_nan;
                int si = sidx ? sidx[i] : i;
                if( !mask || !mask[(size_t)si*m_step] )
                {
                    if( idata )
                        val = (float)idata[(size_t)si*step];
                    else
                        val = fdata[(size_t)si*step];

                    if( fabs(val) >= ord_nan )
                    {
                        sprintf( err, "%d-th value of %d-th (ordered) "
                            "variable (=%g) is too large", i, vi, val );
                        CV_ERROR( CV_StsBadArg, err );
                    }
                    num_valid++;
                }

                if (is_buf_16u)
                    udst[i] = (unsigned short)i; // TODO: memory corruption may be here
                else
                    idst[i] = i;
                _fdst[i] = val;
            }
            if (is_buf_16u)
                std::sort(udst, udst + sample_count, LessThanIdx<float, unsigned short>(_fdst));
            else
                std::sort(idst, idst + sample_count, LessThanIdx<float, int>(_fdst));
        }

        if( vi < var_count )
            data_root->set_num_valid(vi, num_valid);
    }

    // set sample labels
    if (is_buf_16u)
        udst = (unsigned short*)(buf->data.s + (size_t)work_var_count*sample_count);
    else
        idst = buf->data.i + (size_t)work_var_count*sample_count;

    for (i = 0; i < sample_count; i++)
    {
        if (udst)
            udst[i] = sidx ? (unsigned short)sidx[i] : (unsigned short)i;
        else
            idst[i] = sidx ? sidx[i] : i;
    }

    if( cv_n )
    {
        unsigned short* usdst = 0;
        int* idst2 = 0;

        if (is_buf_16u)
        {
            usdst = (unsigned short*)(buf->data.s + (size_t)(get_work_var_count()-1)*sample_count);
            for( i = vi = 0; i < sample_count; i++ )
            {
                usdst[i] = (unsigned short)vi++;
                vi &= vi < cv_n ? -1 : 0; // wrap vi back to 0 once it reaches cv_n
            }

            for( i = 0; i < sample_count; i++ )
            {
                int a = (*rng)(sample_count);
                int b = (*rng)(sample_count);
                unsigned short unsh = (unsigned short)vi;
                CV_SWAP( usdst[a], usdst[b], unsh );
            }
        }
        else
        {
            idst2 = buf->data.i + (size_t)(get_work_var_count()-1)*sample_count;
            for( i = vi = 0; i < sample_count; i++ )
            {
                idst2[i] = vi++;
                vi &= vi < cv_n ? -1 : 0; // wrap vi back to 0 once it reaches cv_n
            }

            for( i = 0; i < sample_count; i++ )
            {
                int a = (*rng)(sample_count);
                int b = (*rng)(sample_count);
                CV_SWAP( idst2[a], idst2[b], vi );
            }
        }
    }
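    // Cross-validation fold labels are assigned round-robin (0, 1, ...,
    // cv_n-1, 0, 1, ...) and then decorrelated from the sample order by
    // sample_count random pair swaps, which approximates a uniform shuffle
    // of the fold assignment.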
    if ( cat_map )
        cat_map->cols = MAX( total_c_count, 1 );

    max_split_size = cvAlign(sizeof(CvDTreeSplit) +
        (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
    CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));

    have_priors = is_classifier && params.priors;
    if( is_classifier )
    {
        int m = get_num_classes();
        double sum = 0;
        CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));
        for( i = 0; i < m; i++ )
        {
            double val = have_priors ? params.priors[i] : 1.;
            if( val <= 0 )
                CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );
            priors->data.db[i] = val;
            sum += val;
        }

        // normalize weights
        if( have_priors )
            cvScale( priors, priors, 1./sum );

        CV_CALL( priors_mult = cvCloneMat( priors ));
        CV_CALL( counts = cvCreateMat( 1, m, CV_32SC1 ));
    }

    CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));
    CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));

    __END__;

    if( data )
        delete data;

    if (_fdst)
        cvFree( &_fdst );
    if (_idst)
        cvFree( &_idst );
    cvFree( &int_ptr );
    cvFree( &pair16u32s_ptr );
    cvReleaseMat( &var_type0 );
    cvReleaseMat( &sample_indices );
    cvReleaseMat( &tmp_map );
}


void CvDTreeTrainData::do_responses_copy()
{
    responses_copy = cvCreateMat( responses->rows, responses->cols, responses->type );
    cvCopy( responses, responses_copy );
    responses = responses_copy;
}


CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )
{
    CvDTreeNode* root = 0;
    CvMat* isubsample_idx = 0;
    CvMat* subsample_co = 0;

    bool isMakeRootCopy = true;

    CV_FUNCNAME( "CvDTreeTrainData::subsample_data" );

    __BEGIN__;

    if( !data_root )
        CV_ERROR( CV_StsError, "No training data has been set" );

    if( _subsample_idx )
    {
        CV_CALL( isubsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));

        if( isubsample_idx->cols + isubsample_idx->rows - 1 == sample_count )
        {
            const int* sidx = isubsample_idx->data.i;
            for( int i = 0; i < sample_count; i++ )
            {
                if( sidx[i] != i )
                {
                    isMakeRootCopy = false;
                    break;
                }
            }
        }
        else
            isMakeRootCopy = false;
    }

    if( isMakeRootCopy )
    {
        // make a copy of the root node
        CvDTreeNode temp;
        int i;
        root = new_node( 0, 1, 0, 0 );
        temp = *root;
        *root = *data_root;
        root->num_valid = temp.num_valid;
        if( root->num_valid )
        {
            for( i = 0; i < var_count; i++ )
                root->num_valid[i] = data_root->num_valid[i];
        }
        root->cv_Tn = temp.cv_Tn;
        root->cv_node_risk = temp.cv_node_risk;
        root->cv_node_error = temp.cv_node_error;
    }
    else
    {
        int* sidx = isubsample_idx->data.i;
        // co - array of count/offset pairs (to handle duplicated values in _subsample_idx)
        int* co, cur_ofs = 0;
        int vi, i;
        int workVarCount = get_work_var_count();
        int count = isubsample_idx->rows + isubsample_idx->cols - 1;

        root = new_node( 0, count, 1, 0 );

        CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
        cvZero( subsample_co );
        co = subsample_co->data.i;
        for( i = 0; i < count; i++ )
            co[sidx[i]*2]++;
        for( i = 0; i < sample_count; i++ )
        {
            if( co[i*2] )
            {
                co[i*2+1] = cur_ofs;
                cur_ofs += co[i*2];
            }
            else
                co[i*2+1] = -1;
        }
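        // Example (hypothetical): for sample_count = 4 and _subsample_idx =
        // {2, 0, 2}, the counts are co[2i] = {1, 0, 2, 0} and the offsets
        // co[2i+1] = {0, -1, 1, -1}: sample 0 lands at destination slot 0,
        // the two copies of sample 2 at slots 1 and 2, and unused samples
        // get offset -1.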
        cv::AutoBuffer<uchar> inn_buf(sample_count*(2*sizeof(int) + sizeof(float)));
        for( vi = 0; vi < workVarCount; vi++ )
        {
            int ci = get_var_type(vi);

            if( ci >= 0 || vi >= var_count )
            {
                int num_valid = 0;
                const int* src = CvDTreeTrainData::get_cat_var_data(data_root, vi, (int*)inn_buf.data());

                if (is_buf_16u)
                {
                    unsigned short* udst = (unsigned short*)(buf->data.s + root->buf_idx*get_length_subbuf() +
                        (size_t)vi*sample_count + root->offset);
                    for( i = 0; i < count; i++ )
                    {
                        int val = src[sidx[i]];
                        udst[i] = (unsigned short)val;
                        num_valid += val >= 0;
                    }
                }
                else
                {
                    int* idst = buf->data.i + root->buf_idx*get_length_subbuf() +
                        (size_t)vi*sample_count + root->offset;
                    for( i = 0; i < count; i++ )
                    {
                        int val = src[sidx[i]];
                        idst[i] = val;
                        num_valid += val >= 0;
                    }
                }

                if( vi < var_count )
                    root->set_num_valid(vi, num_valid);
            }
            else
            {
                int *src_idx_buf = (int*)inn_buf.data();
                float *src_val_buf = (float*)(src_idx_buf + sample_count);
                int* sample_indices_buf = (int*)(src_val_buf + sample_count);
                const int* src_idx = 0;
                const float* src_val = 0;
                get_ord_var_data( data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx, sample_indices_buf );
                int j = 0, idx, count_i;
                int num_valid = data_root->get_num_valid(vi);

                if (is_buf_16u)
                {
                    unsigned short* udst_idx = (unsigned short*)(buf->data.s + root->buf_idx*get_length_subbuf() +
                        (size_t)vi*sample_count + root->offset);
                    for( i = 0; i < num_valid; i++ )
                    {
                        idx = src_idx[i];
                        count_i = co[idx*2];
                        if( count_i )
                            for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
                                udst_idx[j] = (unsigned short)cur_ofs;
                    }

                    root->set_num_valid(vi, j);

                    for( ; i < sample_count; i++ )
                    {
                        idx = src_idx[i];
                        count_i = co[idx*2];
                        if( count_i )
                            for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
                                udst_idx[j] = (unsigned short)cur_ofs;
                    }
                }
                else
                {
                    int* idst_idx = buf->data.i + root->buf_idx*get_length_subbuf() +
                        (size_t)vi*sample_count + root->offset;
                    for( i = 0; i < num_valid; i++ )
                    {
                        idx = src_idx[i];
                        count_i = co[idx*2];
                        if( count_i )
                            for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
                                idst_idx[j] = cur_ofs;
                    }

                    root->set_num_valid(vi, j);

                    for( ; i < sample_count; i++ )
                    {
                        idx = src_idx[i];
                        count_i = co[idx*2];
                        if( count_i )
                            for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
                                idst_idx[j] = cur_ofs;
                    }
                }
            }
        }

        // sample indices subsampling
        const int* sample_idx_src = get_sample_indices(data_root, (int*)inn_buf.data());
        if (is_buf_16u)
        {
            unsigned short* sample_idx_dst = (unsigned short*)(buf->data.s + root->buf_idx*get_length_subbuf() +
                (size_t)workVarCount*sample_count + root->offset);
            for (i = 0; i < count; i++)
                sample_idx_dst[i] = (unsigned short)sample_idx_src[sidx[i]];
        }
        else
        {
            int* sample_idx_dst = buf->data.i + root->buf_idx*get_length_subbuf() +
                (size_t)workVarCount*sample_count + root->offset;
            for (i = 0; i < count; i++)
                sample_idx_dst[i] = sample_idx_src[sidx[i]];
        }
    }

    __END__;

    cvReleaseMat( &isubsample_idx );
    cvReleaseMat( &subsample_co );

    return root;
}


void CvDTreeTrainData::get_vectors( const CvMat* _subsample_idx,
                                    float* values, uchar* missing,
                                    float* _responses, bool get_class_idx )
{
    CvMat* subsample_idx = 0;
    CvMat* subsample_co = 0;

    CV_FUNCNAME( "CvDTreeTrainData::get_vectors" );

    __BEGIN__;

    int i, vi, total = sample_count, count = total, cur_ofs = 0;
    int* sidx = 0;
    int* co = 0;

    cv::AutoBuffer<uchar> inn_buf(sample_count*(2*sizeof(int) + sizeof(float)));
    if( _subsample_idx )
    {
        CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
        sidx = subsample_idx->data.i;
        CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
        co = subsample_co->data.i;
        cvZero( subsample_co );
        count = subsample_idx->cols + subsample_idx->rows - 1;
        for( i = 0; i < count; i++ )
            co[sidx[i]*2]++;
        for( i = 0; i < total; i++ )
        {
            int count_i = co[i*2];
            if( count_i )
            {
                co[i*2+1] = cur_ofs*var_count;
                cur_ofs += count_i;
            }
        }
    }

    if( missing )
        memset( missing, 1, count*var_count );

    for( vi = 0; vi < var_count; vi++ )
    {
        int ci = get_var_type(vi);
        if( ci >= 0 ) // categorical
        {
            float* dst = values + vi;
            uchar* m = missing ? missing + vi : 0;
            const int* src = get_cat_var_data(data_root, vi, (int*)inn_buf.data());

            for( i = 0; i < count; i++, dst += var_count )
            {
                int idx = sidx ? sidx[i] : i;
                int val = src[idx];
                *dst = (float)val;
                if( m )
                {
                    *m = (!is_buf_16u && val < 0) || (is_buf_16u && (val == 65535));
                    m += var_count;
                }
            }
        }
        else // ordered
        {
            float* dst = values + vi;
            uchar* m = missing ? missing + vi : 0;
            int count1 = data_root->get_num_valid(vi);
            float *src_val_buf = (float*)inn_buf.data();
            int* src_idx_buf = (int*)(src_val_buf + sample_count);
            int* sample_indices_buf = src_idx_buf + sample_count;
            const float *src_val = 0;
            const int* src_idx = 0;
            get_ord_var_data(data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx, sample_indices_buf);

            for( i = 0; i < count1; i++ )
            {
                int idx = src_idx[i];
                int count_i = 1;
                if( co )
                {
                    count_i = co[idx*2];
                    cur_ofs = co[idx*2+1];
                }
                else
                    cur_ofs = idx*var_count;
                if( count_i )
                {
                    float val = src_val[i];
                    for( ; count_i > 0; count_i--, cur_ofs += var_count )
                    {
                        dst[cur_ofs] = val;
                        if( m )
                            m[cur_ofs] = 0;
                    }
                }
            }
        }
    }

    // copy responses
    if( _responses )
    {
        if( is_classifier )
        {
            const int* src = get_class_labels(data_root, (int*)inn_buf.data());
            for( i = 0; i < count; i++ )
            {
                int idx = sidx ? sidx[i] : i;
                int val = get_class_idx ? src[idx] :
                    cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
                _responses[i] = (float)val;
            }
        }
        else
        {
            float* val_buf = (float*)inn_buf.data();
            int* sample_idx_buf = (int*)(val_buf + sample_count);
            const float* _values = get_ord_responses(data_root, val_buf, sample_idx_buf);
            for( i = 0; i < count; i++ )
            {
                int idx = sidx ? sidx[i] : i;
                _responses[i] = _values[idx];
            }
        }
    }

    __END__;

    cvReleaseMat( &subsample_idx );
    cvReleaseMat( &subsample_co );
}


CvDTreeNode* CvDTreeTrainData::new_node( CvDTreeNode* parent, int count,
                                         int storage_idx, int offset )
{
    CvDTreeNode* node = (CvDTreeNode*)cvSetNew( node_heap );

    node->sample_count = count;
    node->depth = parent ? parent->depth + 1 : 0;
    node->parent = parent;
    node->left = node->right = 0;
    node->split = 0;
    node->value = 0;
    node->class_idx = 0;
    node->maxlr = 0.;

    node->buf_idx = storage_idx;
    node->offset = offset;
    if( nv_heap )
        node->num_valid = (int*)cvSetNew( nv_heap );
    else
        node->num_valid = 0;
    node->alpha = node->node_risk = node->tree_risk = node->tree_error = 0.;
    node->complexity = 0;

    if( params.cv_folds > 0 && cv_heap )
    {
        int cv_n = params.cv_folds;
        node->Tn = INT_MAX;
        node->cv_Tn = (int*)cvSetNew( cv_heap );
        node->cv_node_risk = (double*)cvAlignPtr(node->cv_Tn + cv_n, sizeof(double));
        node->cv_node_error = node->cv_node_risk + cv_n;
    }
    else
    {
        node->Tn = 0;
        node->cv_Tn = 0;
        node->cv_node_risk = 0;
        node->cv_node_error = 0;
    }

    return node;
}


CvDTreeSplit* CvDTreeTrainData::new_split_ord( int vi, float cmp_val,
                int split_point, int inversed, float quality )
{
    CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
    split->var_idx = vi;
    split->condensed_idx = INT_MIN;
    split->ord.c = cmp_val;
    split->ord.split_point = split_point;
    split->inversed = inversed;
    split->quality = quality;
    split->next = 0;

    return split;
}


CvDTreeSplit* CvDTreeTrainData::new_split_cat( int vi, float quality )
{
    CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
    int i, n = (max_c_count + 31)/32;

    split->var_idx = vi;
    split->condensed_idx = INT_MIN;
    split->inversed = 0;
    split->quality = quality;
    for( i = 0; i < n; i++ )
        split->subset[i] = 0;
    split->next = 0;

    return split;
}
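// The subset[] array of a categorical split is a bitmask over category
// indices: bit c (word c/32, bit c%32) marks which side of the split
// category c is sent to. new_split_cat() clears (max_c_count + 31)/32
// words, enough for the largest category count seen in set_data().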
void CvDTreeTrainData::free_node( CvDTreeNode* node )
{
    CvDTreeSplit* split = node->split;
    free_node_data( node );
    while( split )
    {
        CvDTreeSplit* next = split->next;
        cvSetRemoveByPtr( split_heap, split );
        split = next;
    }
    node->split = 0;
    cvSetRemoveByPtr( node_heap, node );
}


void CvDTreeTrainData::free_node_data( CvDTreeNode* node )
{
    if( node->num_valid )
    {
        cvSetRemoveByPtr( nv_heap, node->num_valid );
        node->num_valid = 0;
    }
    // do not free cv_* fields, as all the cross-validation related data is released at once.
}


void CvDTreeTrainData::free_train_data()
{
    cvReleaseMat( &counts );
    cvReleaseMat( &buf );
    cvReleaseMat( &direction );
    cvReleaseMat( &split_buf );
    cvReleaseMemStorage( &temp_storage );
    cvReleaseMat( &responses_copy );
    cv_heap = nv_heap = 0;
}


void CvDTreeTrainData::clear()
{
    free_train_data();

    cvReleaseMemStorage( &tree_storage );

    cvReleaseMat( &var_idx );
    cvReleaseMat( &var_type );
    cvReleaseMat( &cat_count );
    cvReleaseMat( &cat_ofs );
    cvReleaseMat( &cat_map );
    cvReleaseMat( &priors );
    cvReleaseMat( &priors_mult );

    node_heap = split_heap = 0;

    sample_count = var_all = var_count = max_c_count = ord_var_count = cat_var_count = 0;
    have_labels = have_priors = is_classifier = false;

    buf_count = buf_size = 0;
    shared = false;

    data_root = 0;

    rng = &cv::theRNG();
}


int CvDTreeTrainData::get_num_classes() const
{
    return is_classifier ? cat_count->data.i[cat_var_count] : 0;
}


int CvDTreeTrainData::get_var_type(int vi) const
{
    return var_type->data.i[vi];
}


void CvDTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* sorted_indices_buf,
                                         const float** ord_values, const int** sorted_indices, int* sample_indices_buf )
{
    int vidx = var_idx ? var_idx->data.i[vi] : vi;
    int node_sample_count = n->sample_count;
    int td_step = train_data->step/CV_ELEM_SIZE(train_data->type);

    const int* sample_indices = get_sample_indices(n, sample_indices_buf);

    if( !is_buf_16u )
        *sorted_indices = buf->data.i + n->buf_idx*get_length_subbuf() +
            (size_t)vi*sample_count + n->offset;
    else {
        const unsigned short* short_indices = (const unsigned short*)(buf->data.s + n->buf_idx*get_length_subbuf() +
            (size_t)vi*sample_count + n->offset );
        for( int i = 0; i < node_sample_count; i++ )
            sorted_indices_buf[i] = short_indices[i];
        *sorted_indices = sorted_indices_buf;
    }

    if( tflag == CV_ROW_SAMPLE )
    {
        for( int i = 0; i < node_sample_count &&
             ((((*sorted_indices)[i] >= 0) && !is_buf_16u) || (((*sorted_indices)[i] != 65535) && is_buf_16u)); i++ )
        {
            int idx = (*sorted_indices)[i];
            idx = sample_indices[idx];
            ord_values_buf[i] = *(train_data->data.fl + idx * td_step + vidx);
        }
    }
    else
        for( int i = 0; i < node_sample_count &&
             ((((*sorted_indices)[i] >= 0) && !is_buf_16u) || (((*sorted_indices)[i] != 65535) && is_buf_16u)); i++ )
        {
            int idx = (*sorted_indices)[i];
            idx = sample_indices[idx];
            ord_values_buf[i] = *(train_data->data.fl + vidx * td_step + idx);
        }

    *ord_values = ord_values_buf;
}
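// get_ord_var_data() returns two parallel arrays for a node: sorted_indices,
// the node-local sample order for variable vi (precomputed in set_data() by
// sorting), and ord_values, the variable's values gathered in that order from
// train_data. Callers can then sweep candidate thresholds in a single pass
// over ord_values without re-sorting.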
const int* CvDTreeTrainData::get_class_labels( CvDTreeNode* n, int* labels_buf )
{
    if (is_classifier)
        return get_cat_var_data( n, var_count, labels_buf );
    return 0;
}


const int* CvDTreeTrainData::get_sample_indices( CvDTreeNode* n, int* indices_buf )
{
    return get_cat_var_data( n, get_work_var_count(), indices_buf );
}


const float* CvDTreeTrainData::get_ord_responses( CvDTreeNode* n, float* values_buf, int* sample_indices_buf )
{
    int _sample_count = n->sample_count;
    int r_step = CV_IS_MAT_CONT(responses->type) ? 1 : responses->step/CV_ELEM_SIZE(responses->type);
    const int* indices = get_sample_indices(n, sample_indices_buf);

    for( int i = 0; i < _sample_count &&
         (((indices[i] >= 0) && !is_buf_16u) || ((indices[i] != 65535) && is_buf_16u)); i++ )
    {
        int idx = indices[i];
        values_buf[i] = *(responses->data.fl + idx * r_step);
    }

    return values_buf;
}


const int* CvDTreeTrainData::get_cv_labels( CvDTreeNode* n, int* labels_buf )
{
    if (have_labels)
        return get_cat_var_data( n, get_work_var_count() - 1, labels_buf );
    return 0;
}


const int* CvDTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf )
{
    const int* cat_values = 0;
    if( !is_buf_16u )
        cat_values = buf->data.i + n->buf_idx*get_length_subbuf() +
            (size_t)vi*sample_count + n->offset;
    else {
        const unsigned short* short_values = (const unsigned short*)(buf->data.s + n->buf_idx*get_length_subbuf() +
            (size_t)vi*sample_count + n->offset);
        for( int i = 0; i < n->sample_count; i++ )
            cat_values_buf[i] = short_values[i];
        cat_values = cat_values_buf;
    }
    return cat_values;
}
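// In the 32-bit case get_cat_var_data() can return a pointer straight into
// buf (zero copy); in the 16-bit case it must widen the unsigned short values
// into the caller-provided cat_values_buf, which is why every caller above
// passes a scratch buffer of at least n->sample_count ints.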
int CvDTreeTrainData::get_child_buf_idx( CvDTreeNode* n )
{
    int idx = n->buf_idx + 1;
    if( idx >= buf_count )
        idx = shared ? 1 : 0;
    return idx;
}
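// With shared data (buf_count == 2) sub-buffer 0 keeps the untouched root
// ordering while every derived node is written into sub-buffer 1 (children
// occupy disjoint offset ranges there); with unshared data (buf_count == 1)
// sub-buffer 0 is reused in place.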
void CvDTreeTrainData::write_params( CvFileStorage* fs ) const
{
    CV_FUNCNAME( "CvDTreeTrainData::write_params" );

    __BEGIN__;

    int vi, vcount = var_count;

    cvWriteInt( fs, "is_classifier", is_classifier ? 1 : 0 );
    cvWriteInt( fs, "var_all", var_all );
    cvWriteInt( fs, "var_count", var_count );
    cvWriteInt( fs, "ord_var_count", ord_var_count );
    cvWriteInt( fs, "cat_var_count", cat_var_count );

    cvStartWriteStruct( fs, "training_params", CV_NODE_MAP );
    cvWriteInt( fs, "use_surrogates", params.use_surrogates ? 1 : 0 );
    if( is_classifier )
    {
        cvWriteInt( fs, "max_categories", params.max_categories );
    }
    else
    {
        cvWriteReal( fs, "regression_accuracy", params.regression_accuracy );
    }
    cvWriteInt( fs, "max_depth", params.max_depth );
    cvWriteInt( fs, "min_sample_count", params.min_sample_count );
    cvWriteInt( fs, "cross_validation_folds", params.cv_folds );
    if( params.cv_folds > 1 )
    {
        cvWriteInt( fs, "use_1se_rule", params.use_1se_rule ? 1 : 0 );
        cvWriteInt( fs, "truncate_pruned_tree", params.truncate_pruned_tree ? 1 : 0 );
    }
    if( priors )
        cvWrite( fs, "priors", priors );
    cvEndWriteStruct( fs );

    if( var_idx )
        cvWrite( fs, "var_idx", var_idx );

    cvStartWriteStruct( fs, "var_type", CV_NODE_SEQ + CV_NODE_FLOW );
    for( vi = 0; vi < vcount; vi++ )
        cvWriteInt( fs, 0, var_type->data.i[vi] >= 0 );
    cvEndWriteStruct( fs );

    if( cat_count && (cat_var_count > 0 || is_classifier) )
    {
        CV_ASSERT( cat_count != 0 );
        cvWrite( fs, "cat_count", cat_count );
        cvWrite( fs, "cat_map", cat_map );
    }

    __END__;
}

void CvDTreeTrainData::read_params( CvFileStorage* fs, CvFileNode* node )
{
    CV_FUNCNAME( "CvDTreeTrainData::read_params" );

    __BEGIN__;

    CvFileNode *tparams_node, *vartype_node;
    CvSeqReader reader;
    int vi, max_split_size, tree_block_size;

    is_classifier = (cvReadIntByName( fs, node, "is_classifier" ) != 0);
    var_all = cvReadIntByName( fs, node, "var_all" );
    var_count = cvReadIntByName( fs, node, "var_count", var_all );
    cat_var_count = cvReadIntByName( fs, node, "cat_var_count" );
    ord_var_count = cvReadIntByName( fs, node, "ord_var_count" );

    tparams_node = cvGetFileNodeByName( fs, node, "training_params" );
    if( tparams_node ) // training parameters are optional; skip if absent
    {
        params.use_surrogates = cvReadIntByName( fs, tparams_node, "use_surrogates", 1 ) != 0;

        if( is_classifier )
        {
            params.max_categories = cvReadIntByName( fs, tparams_node, "max_categories" );
        }
        else
        {
            params.regression_accuracy =
                (float)cvReadRealByName( fs, tparams_node, "regression_accuracy" );
        }

        params.max_depth = cvReadIntByName( fs, tparams_node, "max_depth" );
        params.min_sample_count = cvReadIntByName( fs, tparams_node, "min_sample_count" );
        params.cv_folds = cvReadIntByName( fs, tparams_node, "cross_validation_folds" );

        if( params.cv_folds > 1 )
        {
            params.use_1se_rule = cvReadIntByName( fs, tparams_node, "use_1se_rule" ) != 0;
            params.truncate_pruned_tree =
                cvReadIntByName( fs, tparams_node, "truncate_pruned_tree" ) != 0;
        }

        priors = (CvMat*)cvReadByName( fs, tparams_node, "priors" );
        if( priors )
        {
            if( !CV_IS_MAT(priors) )
                CV_ERROR( CV_StsParseError, "priors must be stored as a matrix" );
            priors_mult = cvCloneMat( priors );
        }
    }

    CV_CALL( var_idx = (CvMat*)cvReadByName( fs, node, "var_idx" ));
    if( var_idx )
    {
        if( !CV_IS_MAT(var_idx) ||
            (var_idx->cols != 1 && var_idx->rows != 1) ||
            var_idx->cols + var_idx->rows - 1 != var_count ||
            CV_MAT_TYPE(var_idx->type) != CV_32SC1 )
            CV_ERROR( CV_StsParseError,
  1189. "var_idx (if exist) must be valid 1d integer vector containing <var_count> elements" );
        for( vi = 0; vi < var_count; vi++ )
            if( (unsigned)var_idx->data.i[vi] >= (unsigned)var_all )
                CV_ERROR( CV_StsOutOfRange, "some of var_idx elements are out of range" );
    }

    ////// read var type
    CV_CALL( var_type = cvCreateMat( 1, var_count + 2, CV_32SC1 ));

    cat_var_count = 0;
    ord_var_count = -1;
    vartype_node = cvGetFileNodeByName( fs, node, "var_type" );

    if( vartype_node && CV_NODE_TYPE(vartype_node->tag) == CV_NODE_INT && var_count == 1 )
        var_type->data.i[0] = vartype_node->data.i ? cat_var_count++ : ord_var_count--;
    else
    {
        if( !vartype_node || CV_NODE_TYPE(vartype_node->tag) != CV_NODE_SEQ ||
            vartype_node->data.seq->total != var_count )
            CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );

        cvStartReadSeq( vartype_node->data.seq, &reader );

        for( vi = 0; vi < var_count; vi++ )
        {
            CvFileNode* n = (CvFileNode*)reader.ptr;
            if( CV_NODE_TYPE(n->tag) != CV_NODE_INT || (n->data.i & ~1) )
                CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
            var_type->data.i[vi] = n->data.i ? cat_var_count++ : ord_var_count--;
            CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
        }
    }
    var_type->data.i[var_count] = cat_var_count;

    ord_var_count = ~ord_var_count;
    //////

    if( cat_var_count > 0 || is_classifier )
    {
        int ccount, total_c_count = 0;
        CV_CALL( cat_count = (CvMat*)cvReadByName( fs, node, "cat_count" ));
        CV_CALL( cat_map = (CvMat*)cvReadByName( fs, node, "cat_map" ));

        if( !CV_IS_MAT(cat_count) || !CV_IS_MAT(cat_map) ||
            (cat_count->cols != 1 && cat_count->rows != 1) ||
            CV_MAT_TYPE(cat_count->type) != CV_32SC1 ||
            cat_count->cols + cat_count->rows - 1 != cat_var_count + is_classifier ||
            (cat_map->cols != 1 && cat_map->rows != 1) ||
            CV_MAT_TYPE(cat_map->type) != CV_32SC1 )
            CV_ERROR( CV_StsParseError,
                      "Both cat_count and cat_map must exist and be valid 1d integer vectors of an appropriate size" );

        ccount = cat_var_count + is_classifier;

        CV_CALL( cat_ofs = cvCreateMat( 1, ccount + 1, CV_32SC1 ));
        cat_ofs->data.i[0] = 0;
        max_c_count = 1;

        for( vi = 0; vi < ccount; vi++ )
        {
            int val = cat_count->data.i[vi];
            if( val <= 0 )
                CV_ERROR( CV_StsOutOfRange, "some of cat_count elements are out of range" );
            max_c_count = MAX( max_c_count, val );
            cat_ofs->data.i[vi+1] = total_c_count += val;
        }

        if( cat_map->cols + cat_map->rows - 1 != total_c_count )
            CV_ERROR( CV_StsBadSize,
                      "cat_map vector length is not equal to the total number of categories in all categorical vars" );
    }

    max_split_size = cvAlign(sizeof(CvDTreeSplit) +
        (MAX(0, max_c_count - 33)/32)*sizeof(int), sizeof(void*));

    tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
    tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
    CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
    CV_CALL( node_heap = cvCreateSet( 0, sizeof(node_heap[0]),
                                      sizeof(CvDTreeNode), tree_storage ));
    CV_CALL( split_heap = cvCreateSet( 0, sizeof(split_heap[0]),
                                       max_split_size, tree_storage ));

    __END__;
}

/////////////////////// Decision Tree /////////////////////////

CvDTreeParams::CvDTreeParams() : max_categories(10), max_depth(INT_MAX), min_sample_count(10),
    cv_folds(10), use_surrogates(true), use_1se_rule(true),
    truncate_pruned_tree(true), regression_accuracy(0.01f), priors(0)
{}

CvDTreeParams::CvDTreeParams( int _max_depth, int _min_sample_count,
                              float _regression_accuracy, bool _use_surrogates,
                              int _max_categories, int _cv_folds,
                              bool _use_1se_rule, bool _truncate_pruned_tree,
                              const float* _priors ) :
    max_categories(_max_categories), max_depth(_max_depth),
    min_sample_count(_min_sample_count), cv_folds(_cv_folds),
    use_surrogates(_use_surrogates), use_1se_rule(_use_1se_rule),
    truncate_pruned_tree(_truncate_pruned_tree),
    regression_accuracy(_regression_accuracy),
    priors(_priors)
{}
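
/* Illustrative usage sketch (editorial, not part of the library): the second
   constructor spells out every training parameter, in the order of its
   parameter list above, e.g.

       CvDTreeParams params( 8,      // max_depth
                             20,     // min_sample_count
                             0.01f,  // regression_accuracy
                             true,   // use_surrogates
                             10,     // max_categories
                             5,      // cv_folds
                             true,   // use_1se_rule
                             true,   // truncate_pruned_tree
                             0 );    // priors (none)

   Note that only the priors pointer is stored here; the array it points to,
   if any, has to stay valid until training has consumed it. */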

CvDTree::CvDTree()
{
    data = 0;
    var_importance = 0;
    default_model_name = "my_tree";

    clear();
}

void CvDTree::clear()
{
    cvReleaseMat( &var_importance );
    if( data )
    {
        if( !data->shared )
            delete data;
        else
            free_tree();
        data = 0;
    }
    root = 0;
    pruned_tree_idx = -1;
}

CvDTree::~CvDTree()
{
    clear();
}

const CvDTreeNode* CvDTree::get_root() const
{
    return root;
}

int CvDTree::get_pruned_tree_idx() const
{
    return pruned_tree_idx;
}

CvDTreeTrainData* CvDTree::get_data()
{
    return data;
}

bool CvDTree::train( const CvMat* _train_data, int _tflag,
                     const CvMat* _responses, const CvMat* _var_idx,
                     const CvMat* _sample_idx, const CvMat* _var_type,
                     const CvMat* _missing_mask, CvDTreeParams _params )
{
    bool result = false;

    CV_FUNCNAME( "CvDTree::train" );

    __BEGIN__;

    clear();
    data = new CvDTreeTrainData( _train_data, _tflag, _responses,
                                 _var_idx, _sample_idx, _var_type,
                                 _missing_mask, _params, false );
    CV_CALL( result = do_train(0) );

    __END__;

    return result;
}

bool CvDTree::train( const Mat& _train_data, int _tflag,
                     const Mat& _responses, const Mat& _var_idx,
                     const Mat& _sample_idx, const Mat& _var_type,
                     const Mat& _missing_mask, CvDTreeParams _params )
{
    train_data_hdr = cvMat(_train_data);
    train_data_mat = _train_data;
    responses_hdr = cvMat(_responses);
    responses_mat = _responses;

    CvMat vidx = cvMat(_var_idx), sidx = cvMat(_sample_idx),
          vtype = cvMat(_var_type), mmask = cvMat(_missing_mask);

    return train(&train_data_hdr, _tflag, &responses_hdr, vidx.data.ptr ? &vidx : 0, sidx.data.ptr ? &sidx : 0,
                 vtype.data.ptr ? &vtype : 0, mmask.data.ptr ? &mmask : 0, _params);
}
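
/* Editorial sketch of the cv::Mat overload above (variable names are made up):
   empty Mat() arguments become CvMat headers with a null data pointer and are
   forwarded as 0, i.e. "not specified", so a minimal call is

       cv::Mat samples;    // CV_32F, one sample per row
       cv::Mat responses;  // CV_32F, one value per sample
       CvDTree tree;
       tree.train( samples, CV_ROW_SAMPLE, responses,
                   cv::Mat(), cv::Mat(), cv::Mat(), cv::Mat(),
                   CvDTreeParams() );
*/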

bool CvDTree::train( CvMLData* _data, CvDTreeParams _params )
{
    bool result = false;

    CV_FUNCNAME( "CvDTree::train" );

    __BEGIN__;

    const CvMat* values = _data->get_values();
    const CvMat* response = _data->get_responses();
    const CvMat* missing = _data->get_missing();
    const CvMat* var_types = _data->get_var_types();
    const CvMat* train_sidx = _data->get_train_sample_idx();
    const CvMat* var_idx = _data->get_var_idx();

    CV_CALL( result = train( values, CV_ROW_SAMPLE, response, var_idx,
                             train_sidx, var_types, missing, _params ) );

    __END__;

    return result;
}

bool CvDTree::train( CvDTreeTrainData* _data, const CvMat* _subsample_idx )
{
    bool result = false;

    CV_FUNCNAME( "CvDTree::train" );

    __BEGIN__;

    clear();
    data = _data;
    data->shared = true;
    CV_CALL( result = do_train(_subsample_idx));

    __END__;

    return result;
}

bool CvDTree::do_train( const CvMat* _subsample_idx )
{
    bool result = false;

    CV_FUNCNAME( "CvDTree::do_train" );

    __BEGIN__;

    root = data->subsample_data( _subsample_idx );

    CV_CALL( try_split_node(root));

    if( root->split )
    {
        CV_Assert( root->left );
        CV_Assert( root->right );

        if( data->params.cv_folds > 0 )
            CV_CALL( prune_cv() );

        if( !data->shared )
            data->free_train_data();

        result = true;
    }

    __END__;

    return result;
}

void CvDTree::try_split_node( CvDTreeNode* node )
{
    CvDTreeSplit* best_split = 0;
    int i, n = node->sample_count, vi;
    bool can_split = true;
    double quality_scale;

    calc_node_value( node );

    if( node->sample_count <= data->params.min_sample_count ||
        node->depth >= data->params.max_depth )
        can_split = false;

    if( can_split && data->is_classifier )
    {
        // check if we have a "pure" node,
        // we assume that cls_count is filled by calc_node_value()
        int* cls_count = data->counts->data.i;
        int nz = 0, m = data->get_num_classes();
        for( i = 0; i < m; i++ )
            nz += cls_count[i] != 0;
        if( nz == 1 ) // there is only one class
            can_split = false;
    }
    else if( can_split )
    {
        if( sqrt(node->node_risk)/n < data->params.regression_accuracy )
            can_split = false;
    }

    if( can_split )
    {
        best_split = find_best_split(node);
        // TODO: check the split quality ...
        node->split = best_split;
    }

    if( !can_split || !best_split )
    {
        data->free_node_data(node);
        return;
    }

    quality_scale = calc_node_dir( node );

    if( data->params.use_surrogates )
    {
        // find all the surrogate splits
        // and sort them by their similarity to the primary one
        for( vi = 0; vi < data->var_count; vi++ )
        {
            CvDTreeSplit* split;
            int ci = data->get_var_type(vi);

            if( vi == best_split->var_idx )
                continue;

            if( ci >= 0 )
                split = find_surrogate_split_cat( node, vi );
            else
                split = find_surrogate_split_ord( node, vi );

            if( split )
            {
                // insert the split
                CvDTreeSplit* prev_split = node->split;
                split->quality = (float)(split->quality*quality_scale);

                while( prev_split->next &&
                       prev_split->next->quality > split->quality )
                    prev_split = prev_split->next;
                split->next = prev_split->next;
                prev_split->next = split;
            }
        }
    }

    split_node_data( node );
    try_split_node( node->left );
    try_split_node( node->right );
}

// calculate direction (left(-1), right(1), missing(0))
// for each sample using the best split.
// The function returns the scale coefficient for surrogate split quality factors.
// The scale is applied to normalize surrogate split quality relative to the
// best (primary) split quality: if a surrogate split is absolutely identical
// to the primary split, its quality will be set to the maximum value, i.e.
// the quality of the primary split; otherwise it will be lower.
// Besides, the function computes node->maxlr, the minimum acceptable quality
// (without the above-mentioned scale) for a surrogate split; surrogate splits
// with quality not exceeding node->maxlr are discarded.
double CvDTree::calc_node_dir( CvDTreeNode* node )
{
    char* dir = (char*)data->direction->data.ptr;
    int i, n = node->sample_count, vi = node->split->var_idx;
    double L, R;

    assert( !node->split->inversed );

    if( data->get_var_type(vi) >= 0 ) // split on categorical var
    {
        cv::AutoBuffer<int> inn_buf(n*(!data->have_priors ? 1 : 2));
        int* labels_buf = inn_buf.data();
        const int* labels = data->get_cat_var_data( node, vi, labels_buf );
        const int* subset = node->split->subset;

        if( !data->have_priors )
        {
            int sum = 0, sum_abs = 0;

            for( i = 0; i < n; i++ )
            {
                int idx = labels[i];
                int d = ( ((idx >= 0) && (!data->is_buf_16u)) || ((idx != 65535) && (data->is_buf_16u)) ) ?
                    CV_DTREE_CAT_DIR(idx, subset) : 0;
                sum += d; sum_abs += d & 1;
                dir[i] = (char)d;
            }

            R = (sum_abs + sum) >> 1;
            L = (sum_abs - sum) >> 1;
        }
        else
        {
            const double* priors = data->priors_mult->data.db;
            double sum = 0, sum_abs = 0;
            int* responses_buf = labels_buf + n;
            const int* responses = data->get_class_labels(node, responses_buf);

            for( i = 0; i < n; i++ )
            {
                int idx = labels[i];
                double w = priors[responses[i]];
                int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx, subset) : 0;
                sum += d*w; sum_abs += (d & 1)*w;
                dir[i] = (char)d;
            }

            R = (sum_abs + sum) * 0.5;
            L = (sum_abs - sum) * 0.5;
        }
    }
    else // split on ordered var
    {
        int split_point = node->split->ord.split_point;
        int n1 = node->get_num_valid(vi);

        cv::AutoBuffer<uchar> inn_buf(n*(sizeof(int)*(data->have_priors ? 3 : 2) + sizeof(float)));
        float* val_buf = (float*)inn_buf.data();
        int* sorted_buf = (int*)(val_buf + n);
        int* sample_idx_buf = sorted_buf + n;
        const float* val = 0;
        const int* sorted = 0;
        data->get_ord_var_data( node, vi, val_buf, sorted_buf, &val, &sorted, sample_idx_buf );

        assert( 0 <= split_point && split_point < n1 - 1 );

        if( !data->have_priors )
        {
            for( i = 0; i <= split_point; i++ )
                dir[sorted[i]] = (char)-1;
            for( ; i < n1; i++ )
                dir[sorted[i]] = (char)1;
            for( ; i < n; i++ )
                dir[sorted[i]] = (char)0;
            L = split_point + 1;       // samples 0..split_point were sent to the left
            R = n1 - split_point - 1;  // the remaining valid samples went to the right
        }
        else
        {
            const double* priors = data->priors_mult->data.db;
            int* responses_buf = sample_idx_buf + n;
            const int* responses = data->get_class_labels(node, responses_buf);
            L = R = 0;

            for( i = 0; i <= split_point; i++ )
            {
                int idx = sorted[i];
                double w = priors[responses[idx]];
                dir[idx] = (char)-1;
                L += w;
            }

            for( ; i < n1; i++ )
            {
                int idx = sorted[i];
                double w = priors[responses[idx]];
                dir[idx] = (char)1;
                R += w;
            }

            for( ; i < n; i++ )
                dir[sorted[i]] = (char)0;
        }
    }

    node->maxlr = MAX( L, R );
    return node->split->quality/(L + R);
}
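
/* Editorial note on the value returned above: a surrogate split's raw quality
   is the (weighted) number of samples on which it agrees with the primary
   split, so L + R is the best score any surrogate can reach. try_split_node()
   multiplies each surrogate's quality by this function's return value,
   quality/(L + R), so a perfect surrogate ends up with exactly the primary
   split's quality. Example with made-up numbers: L = 30, R = 70, primary
   quality 0.8 -> a surrogate agreeing on all 100 samples scores
   100 * 0.8/100 = 0.8, while node->maxlr = 70 is the agreement any candidate
   has to beat, since the trivial rule "always follow the majority side"
   already agrees on max(L, R) samples. */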

namespace cv
{

void DefaultDeleter<CvDTreeSplit>::operator ()(CvDTreeSplit* obj) const { fastFree(obj); }

DTreeBestSplitFinder::DTreeBestSplitFinder( CvDTree* _tree, CvDTreeNode* _node )
{
    tree = _tree;
    node = _node;
    splitSize = tree->get_data()->split_heap->elem_size;

    bestSplit.reset((CvDTreeSplit*)fastMalloc(splitSize));
    memset(bestSplit.get(), 0, splitSize);
    bestSplit->quality = -1;
    bestSplit->condensed_idx = INT_MIN;
    split.reset((CvDTreeSplit*)fastMalloc(splitSize));
    memset(split.get(), 0, splitSize);
    //haveSplit = false;
}

DTreeBestSplitFinder::DTreeBestSplitFinder( const DTreeBestSplitFinder& finder, Split )
{
    tree = finder.tree;
    node = finder.node;
    splitSize = tree->get_data()->split_heap->elem_size;

    bestSplit.reset((CvDTreeSplit*)fastMalloc(splitSize));
    memcpy(bestSplit.get(), finder.bestSplit.get(), splitSize);
    split.reset((CvDTreeSplit*)fastMalloc(splitSize));
    memset(split.get(), 0, splitSize);
}

void DTreeBestSplitFinder::operator()(const BlockedRange& range)
{
    int vi, vi1 = range.begin(), vi2 = range.end();
    int n = node->sample_count;
    CvDTreeTrainData* data = tree->get_data();
    AutoBuffer<uchar> inn_buf(2*n*(sizeof(int) + sizeof(float)));

    for( vi = vi1; vi < vi2; vi++ )
    {
        CvDTreeSplit *res;
        int ci = data->get_var_type(vi);
        if( node->get_num_valid(vi) <= 1 )
            continue;

        if( data->is_classifier )
        {
            if( ci >= 0 )
                res = tree->find_split_cat_class( node, vi, bestSplit->quality, split, inn_buf.data() );
            else
                res = tree->find_split_ord_class( node, vi, bestSplit->quality, split, inn_buf.data() );
        }
        else
        {
            if( ci >= 0 )
                res = tree->find_split_cat_reg( node, vi, bestSplit->quality, split, inn_buf.data() );
            else
                res = tree->find_split_ord_reg( node, vi, bestSplit->quality, split, inn_buf.data() );
        }

        if( res && bestSplit->quality < split->quality )
            memcpy( bestSplit.get(), split.get(), splitSize );
    }
}

void DTreeBestSplitFinder::join( DTreeBestSplitFinder& rhs )
{
    if( bestSplit->quality < rhs.bestSplit->quality )
        memcpy( bestSplit.get(), rhs.bestSplit.get(), splitSize );
}

}

CvDTreeSplit* CvDTree::find_best_split( CvDTreeNode* node )
{
    DTreeBestSplitFinder finder( this, node );

    cv::parallel_reduce(cv::BlockedRange(0, data->var_count), finder);

    CvDTreeSplit *bestSplit = 0;
    if( finder.bestSplit->quality > 0 )
    {
        bestSplit = data->new_split_cat( 0, -1.0f );
        memcpy( bestSplit, finder.bestSplit, finder.splitSize );
    }

    return bestSplit;
}

CvDTreeSplit* CvDTree::find_split_ord_class( CvDTreeNode* node, int vi,
                                             float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
{
    const float epsilon = FLT_EPSILON*2;
    int n = node->sample_count;
    int n1 = node->get_num_valid(vi);
    int m = data->get_num_classes();

    int base_size = 2*m*sizeof(int);
    cv::AutoBuffer<uchar> inn_buf(base_size);
    if( !_ext_buf )
        inn_buf.allocate(base_size + n*(3*sizeof(int) + sizeof(float)));
    uchar* base_buf = inn_buf.data();
    uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;

    float* values_buf = (float*)ext_buf;
    int* sorted_indices_buf = (int*)(values_buf + n);
    int* sample_indices_buf = sorted_indices_buf + n;
    const float* values = 0;
    const int* sorted_indices = 0;
    data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values,
                            &sorted_indices, sample_indices_buf );
    int* responses_buf = sample_indices_buf + n;
    const int* responses = data->get_class_labels( node, responses_buf );

    const int* rc0 = data->counts->data.i;
    int* lc = (int*)base_buf;
    int* rc = lc + m;
    int i, best_i = -1;
    double lsum2 = 0, rsum2 = 0, best_val = init_quality;
    const double* priors = data->have_priors ? data->priors_mult->data.db : 0;

    // init arrays of class instance counters on both sides of the split
    for( i = 0; i < m; i++ )
    {
        lc[i] = 0;
        rc[i] = rc0[i];
    }

    // compensate for missing values
    for( i = n1; i < n; i++ )
    {
        rc[responses[sorted_indices[i]]]--;
    }

    if( !priors )
    {
        int L = 0, R = n1;

        for( i = 0; i < m; i++ )
            rsum2 += (double)rc[i]*rc[i];

        for( i = 0; i < n1 - 1; i++ )
        {
            int idx = responses[sorted_indices[i]];
            int lv, rv;
            L++; R--;
            lv = lc[idx]; rv = rc[idx];
            lsum2 += lv*2 + 1;
            rsum2 -= rv*2 - 1;
            lc[idx] = lv + 1; rc[idx] = rv - 1;

            if( values[i] + epsilon < values[i+1] )
            {
                double val = (lsum2*R + rsum2*L)/((double)L*R);
                if( best_val < val )
                {
                    best_val = val;
                    best_i = i;
                }
            }
        }
    }
    else
    {
        double L = 0, R = 0;
        for( i = 0; i < m; i++ )
        {
            double wv = rc[i]*priors[i];
            R += wv;
            rsum2 += wv*wv;
        }

        for( i = 0; i < n1 - 1; i++ )
        {
            int idx = responses[sorted_indices[i]];
            int lv, rv;
            double p = priors[idx], p2 = p*p;
            L += p; R -= p;
            lv = lc[idx]; rv = rc[idx];
            lsum2 += p2*(lv*2 + 1);
            rsum2 -= p2*(rv*2 - 1);
            lc[idx] = lv + 1; rc[idx] = rv - 1;

            if( values[i] + epsilon < values[i+1] )
            {
                double val = (lsum2*R + rsum2*L)/((double)L*R);
                if( best_val < val )
                {
                    best_val = val;
                    best_i = i;
                }
            }
        }
    }

    CvDTreeSplit* split = 0;
    if( best_i >= 0 )
    {
        split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
        split->var_idx = vi;
        split->ord.c = (values[best_i] + values[best_i+1])*0.5f;
        split->ord.split_point = best_i;
        split->inversed = 0;
        split->quality = (float)best_val;
    }
    return split;
}
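
/* Editorial note on the split score used above: with lsum2 = sum_k lc[k]^2 and
   rsum2 = sum_k rc[k]^2, the quantity

       val = lsum2/L + rsum2/R   ( = (lsum2*R + rsum2*L)/(L*R) )

   is, up to the constant n, equivalent to minimizing the weighted Gini
   impurity L*Gini(left) + R*Gini(right), since Gini(left) = 1 - lsum2/L^2.
   A tiny worked example: class counts {4,0} | {1,5} give
   val = 16/4 + 26/6 ~= 8.33, while {3,1} | {2,4} give
   val = 10/4 + 20/6 ~= 5.83, so the purer partition scores higher. The
   incremental updates lsum2 += 2*lv + 1 and rsum2 -= 2*rv - 1 are just the
   expansions of (lv+1)^2 - lv^2 and (rv-1)^2 - rv^2, which keeps the scan
   O(n) per variable. */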

void CvDTree::cluster_categories( const int* vectors, int n, int m,
                                  int* csums, int k, int* labels )
{
    // TODO: consider adding priors (class weights) and sample weights to the clustering algorithm
    int iters = 0, max_iters = 100;
    int i, j, idx;
    cv::AutoBuffer<double> buf(n + k);
    double *v_weights = buf.data(), *c_weights = buf.data() + n;
    bool modified = true;
    RNG* r = data->rng;

    // assign labels randomly
    for( i = 0; i < n; i++ )
    {
        int sum = 0;
        const int* v = vectors + i*m;
        labels[i] = i < k ? i : r->uniform(0, k);

        // compute weight of each vector
        for( j = 0; j < m; j++ )
            sum += v[j];
        v_weights[i] = sum ? 1./sum : 0.;
    }

    for( i = 0; i < n; i++ )
    {
        int i1 = (*r)(n);
        int i2 = (*r)(n);
        CV_SWAP( labels[i1], labels[i2], j );
    }

    for( iters = 0; iters <= max_iters; iters++ )
    {
        // calculate csums
        for( i = 0; i < k; i++ )
        {
            for( j = 0; j < m; j++ )
                csums[i*m + j] = 0;
        }

        for( i = 0; i < n; i++ )
        {
            const int* v = vectors + i*m;
            int* s = csums + labels[i]*m;
            for( j = 0; j < m; j++ )
                s[j] += v[j];
        }

        // exit the loop here, when we have up-to-date csums
        if( iters == max_iters || !modified )
            break;

        modified = false;

        // calculate weight of each cluster
        for( i = 0; i < k; i++ )
        {
            const int* s = csums + i*m;
            int sum = 0;
            for( j = 0; j < m; j++ )
                sum += s[j];
            c_weights[i] = sum ? 1./sum : 0;
        }

        // now for each vector determine the closest cluster
        for( i = 0; i < n; i++ )
        {
            const int* v = vectors + i*m;
            double alpha = v_weights[i];
            double min_dist2 = DBL_MAX;
            int min_idx = -1;

            for( idx = 0; idx < k; idx++ )
            {
                const int* s = csums + idx*m;
                double dist2 = 0., beta = c_weights[idx];
                for( j = 0; j < m; j++ )
                {
                    double t = v[j]*alpha - s[j]*beta;
                    dist2 += t*t;
                }
                if( min_dist2 > dist2 )
                {
                    min_dist2 = dist2;
                    min_idx = idx;
                }
            }

            if( min_idx != labels[i] )
                modified = true;
            labels[i] = min_idx;
        }
    }
}
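
/* Editorial sketch of what cluster_categories() computes: each of the n input
   categories is described by its row of m class counts, every vector and every
   cluster sum is normalized to a distribution by the 1/sum weights, and the
   loop is a plain k-means on those distributions:

       assign initial labels (k seeds, then random)  // initialization above
       repeat until unchanged or max_iters reached:
           csums[c]  = sum of vectors with label c
           labels[i] = argmin_c || v_i/|v_i| - csums[c]/|csums[c]| ||^2

   so categories whose class-frequency profiles look alike land in the same
   cluster, shrinking mi categories to at most max_categories before the
   subset search in find_split_cat_class() below. */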

CvDTreeSplit* CvDTree::find_split_cat_class( CvDTreeNode* node, int vi, float init_quality,
                                             CvDTreeSplit* _split, uchar* _ext_buf )
{
    int ci = data->get_var_type(vi);
    int n = node->sample_count;
    int m = data->get_num_classes();
    int _mi = data->cat_count->data.i[ci], mi = _mi;

    int base_size = m*(3 + mi)*sizeof(int) + (mi+1)*sizeof(double);
    if( m > 2 && mi > data->params.max_categories )
        base_size += (m*std::min(data->params.max_categories, n) + mi)*sizeof(int);
    else
        base_size += mi*sizeof(int*);
    cv::AutoBuffer<uchar> inn_buf(base_size);
    if( !_ext_buf )
        inn_buf.allocate(base_size + 2*n*sizeof(int));
    uchar* base_buf = inn_buf.data();
    uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;

    int* lc = (int*)base_buf;
    int* rc = lc + m;
    int* _cjk = rc + m*2, *cjk = _cjk;
    double* c_weights = (double*)alignPtr(cjk + m*mi, sizeof(double));

    int* labels_buf = (int*)ext_buf;
    const int* labels = data->get_cat_var_data(node, vi, labels_buf);
    int* responses_buf = labels_buf + n;
    const int* responses = data->get_class_labels(node, responses_buf);

    int* cluster_labels = 0;
    int** int_ptr = 0;
    int i, j, k, idx;
    double L = 0, R = 0;
    double best_val = init_quality;
    int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;
    const double* priors = data->priors_mult->data.db;

    // init array of counters:
    // c_{jk} - number of samples that have vi-th input variable = j and response = k.
    for( j = -1; j < mi; j++ )
        for( k = 0; k < m; k++ )
            cjk[j*m + k] = 0;

    for( i = 0; i < n; i++ )
    {
        j = ( labels[i] == 65535 && data->is_buf_16u ) ? -1 : labels[i];
        k = responses[i];
        cjk[j*m + k]++;
    }

    if( m > 2 )
    {
        if( mi > data->params.max_categories )
        {
            mi = MIN(data->params.max_categories, n);
            cjk = (int*)(c_weights + _mi);
            cluster_labels = cjk + m*mi;
            cluster_categories( _cjk, _mi, m, cjk, mi, cluster_labels );
        }
        subset_i = 1;
        subset_n = 1 << mi;
    }
    else
    {
        assert( m == 2 );
        int_ptr = (int**)(c_weights + _mi);
        for( j = 0; j < mi; j++ )
            int_ptr[j] = cjk + j*2 + 1;
        std::sort(int_ptr, int_ptr + mi, LessThanPtr<int>());
        subset_i = 0;
        subset_n = mi;
    }

    for( k = 0; k < m; k++ )
    {
        int sum = 0;
        for( j = 0; j < mi; j++ )
            sum += cjk[j*m + k];
        rc[k] = sum;
        lc[k] = 0;
    }

    for( j = 0; j < mi; j++ )
    {
        double sum = 0;
        for( k = 0; k < m; k++ )
            sum += cjk[j*m + k]*priors[k];
        c_weights[j] = sum;
        R += c_weights[j];
    }

    for( ; subset_i < subset_n; subset_i++ )
    {
        double weight;
        int* crow;
        double lsum2 = 0, rsum2 = 0;

        if( m == 2 )
            idx = (int)(int_ptr[subset_i] - cjk)/2;
        else
        {
            int graycode = (subset_i>>1)^subset_i;
            int diff = graycode ^ prevcode;

            // determine index of the changed bit.
            Cv32suf u;
            idx = diff >= (1 << 16) ? 16 : 0;
            u.f = (float)(((diff >> 16) | diff) & 65535);
            idx += (u.i >> 23) - 127;
            subtract = graycode < prevcode;
            prevcode = graycode;
        }

        crow = cjk + idx*m;
        weight = c_weights[idx];
        if( weight < FLT_EPSILON )
            continue;

        if( !subtract )
        {
            for( k = 0; k < m; k++ )
            {
                int t = crow[k];
                int lval = lc[k] + t;
                int rval = rc[k] - t;
                double p = priors[k], p2 = p*p;
                lsum2 += p2*lval*lval;
                rsum2 += p2*rval*rval;
                lc[k] = lval; rc[k] = rval;
            }
            L += weight;
            R -= weight;
        }
        else
        {
            for( k = 0; k < m; k++ )
            {
                int t = crow[k];
                int lval = lc[k] - t;
                int rval = rc[k] + t;
                double p = priors[k], p2 = p*p;
                lsum2 += p2*lval*lval;
                rsum2 += p2*rval*rval;
                lc[k] = lval; rc[k] = rval;
            }
            L -= weight;
            R += weight;
        }

        if( L > FLT_EPSILON && R > FLT_EPSILON )
        {
            double val = (lsum2*R + rsum2*L)/((double)L*R);
            if( best_val < val )
            {
                best_val = val;
                best_subset = subset_i;
            }
        }
    }

    CvDTreeSplit* split = 0;
    if( best_subset >= 0 )
    {
        split = _split ? _split : data->new_split_cat( 0, -1.0f );
        split->var_idx = vi;
        split->quality = (float)best_val;
        memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
        if( m == 2 )
        {
            for( i = 0; i <= best_subset; i++ )
            {
                idx = (int)(int_ptr[i] - cjk) >> 1;
                split->subset[idx >> 5] |= 1 << (idx & 31);
            }
        }
        else
        {
            for( i = 0; i < _mi; i++ )
            {
                idx = cluster_labels ? cluster_labels[i] : i;
                if( best_subset & (1 << idx) )
                    split->subset[i >> 5] |= 1 << (i & 31);
            }
        }
    }
    return split;
}
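
/* Editorial note on the subset enumeration above: for m > 2 classes the loop
   walks all 2^mi category subsets in Gray-code order, so consecutive subsets
   differ in exactly one category and the class counters lc/rc can be updated
   by adding or subtracting a single cjk row instead of being rebuilt. E.g.
   for mi = 3 the visited codes are 000, 001, 011, 010, 110, 111, 101, 100.
   The float trick `u.f = (float)x; idx = (u.i >> 23) - 127` reads the IEEE-754
   exponent of x, i.e. floor(log2(x)), to locate the changed bit without a
   loop. In the two-class case no enumeration is needed: sorting categories by
   their second-class count (the int_ptr array) and scanning prefixes is known
   to contain the optimal subset, reducing the search from 2^mi to mi steps. */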

CvDTreeSplit* CvDTree::find_split_ord_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
{
    const float epsilon = FLT_EPSILON*2;
    int n = node->sample_count;
    int n1 = node->get_num_valid(vi);

    cv::AutoBuffer<uchar> inn_buf;
    if( !_ext_buf )
        inn_buf.allocate(2*n*(sizeof(int) + sizeof(float)));
    uchar* ext_buf = _ext_buf ? _ext_buf : inn_buf.data();
    float* values_buf = (float*)ext_buf;
    int* sorted_indices_buf = (int*)(values_buf + n);
    int* sample_indices_buf = sorted_indices_buf + n;
    const float* values = 0;
    const int* sorted_indices = 0;
    data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );
    float* responses_buf = (float*)(sample_indices_buf + n);
    const float* responses = data->get_ord_responses( node, responses_buf, sample_indices_buf );

    int i, best_i = -1;
    double best_val = init_quality, lsum = 0, rsum = node->value*n;
    int L = 0, R = n1;

    // compensate for missing values
    for( i = n1; i < n; i++ )
        rsum -= responses[sorted_indices[i]];

    // find the optimal split
    for( i = 0; i < n1 - 1; i++ )
    {
        float t = responses[sorted_indices[i]];
        L++; R--;
        lsum += t;
        rsum -= t;

        if( values[i] + epsilon < values[i+1] )
        {
            double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
            if( best_val < val )
            {
                best_val = val;
                best_i = i;
            }
        }
    }

    CvDTreeSplit* split = 0;
    if( best_i >= 0 )
    {
        split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
        split->var_idx = vi;
        split->ord.c = (values[best_i] + values[best_i+1])*0.5f;
        split->ord.split_point = best_i;
        split->inversed = 0;
        split->quality = (float)best_val;
    }
    return split;
}
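
/* Editorial derivation of the regression score: for a candidate split, the
   summed squared error of predicting each side by its own mean is

       SSE = sum_i y_i^2 - (lsum^2/L + rsum^2/R),

   and since sum_i y_i^2 does not depend on the split point, minimizing SSE is
   the same as maximizing lsum^2/L + rsum^2/R = (lsum^2*R + rsum^2*L)/(L*R),
   which is exactly `val` above. Carrying only lsum and rsum lets the scan
   over the sorted values run in O(n) with two additions per step. */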

CvDTreeSplit* CvDTree::find_split_cat_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
{
    int ci = data->get_var_type(vi);
    int n = node->sample_count;
    int mi = data->cat_count->data.i[ci];

    int base_size = (mi+2)*sizeof(double) + (mi+1)*(sizeof(int) + sizeof(double*));
    cv::AutoBuffer<uchar> inn_buf(base_size);
    if( !_ext_buf )
        inn_buf.allocate(base_size + n*(2*sizeof(int) + sizeof(float)));
    uchar* base_buf = inn_buf.data();
    uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;

    int* labels_buf = (int*)ext_buf;
    const int* labels = data->get_cat_var_data(node, vi, labels_buf);
    float* responses_buf = (float*)(labels_buf + n);
    int* sample_indices_buf = (int*)(responses_buf + n);
    const float* responses = data->get_ord_responses(node, responses_buf, sample_indices_buf);

    double* sum = (double*)cv::alignPtr(base_buf, sizeof(double)) + 1;
    int* counts = (int*)(sum + mi) + 1;
    double** sum_ptr = (double**)(counts + mi);
    int i, L = 0, R = 0;
    double best_val = init_quality, lsum = 0, rsum = 0;
    int best_subset = -1, subset_i;

    for( i = -1; i < mi; i++ )
        sum[i] = counts[i] = 0;

    // calculate sum response and weight of each category of the input var
    for( i = 0; i < n; i++ )
    {
        int idx = ( (labels[i] == 65535) && data->is_buf_16u ) ? -1 : labels[i];
        double s = sum[idx] + responses[i];
        int nc = counts[idx] + 1;
        sum[idx] = s;
        counts[idx] = nc;
    }

    // calculate average response in each category
    for( i = 0; i < mi; i++ )
    {
        R += counts[i];
        rsum += sum[i];
        sum[i] /= MAX(counts[i], 1);
        sum_ptr[i] = sum + i;
    }

    std::sort(sum_ptr, sum_ptr + mi, LessThanPtr<double>());

    // revert back to unnormalized sums
    // (there should be very little loss of accuracy)
    for( i = 0; i < mi; i++ )
        sum[i] *= counts[i];

    for( subset_i = 0; subset_i < mi-1; subset_i++ )
    {
        int idx = (int)(sum_ptr[subset_i] - sum);
        int ni = counts[idx];

        if( ni )
        {
            double s = sum[idx];
            lsum += s; L += ni;
            rsum -= s; R -= ni;

            if( L && R )
            {
                double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
                if( best_val < val )
                {
                    best_val = val;
                    best_subset = subset_i;
                }
            }
        }
    }

    CvDTreeSplit* split = 0;
    if( best_subset >= 0 )
    {
        split = _split ? _split : data->new_split_cat( 0, -1.0f );
        split->var_idx = vi;
        split->quality = (float)best_val;
        memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
        for( i = 0; i <= best_subset; i++ )
        {
            int idx = (int)(sum_ptr[i] - sum);
            split->subset[idx >> 5] |= 1 << (idx & 31);
        }
    }
    return split;
}
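
/* Editorial note: sorting the categories by their mean response and scanning
   only the mi-1 prefix subsets (rather than all 2^(mi-1) groupings) is
   justified by a classical result for squared-error regression trees: some
   optimal binary grouping of categories always splits the mean-ordered list
   into a prefix and a suffix. That is why the code sorts `sum_ptr` by average
   response and then reuses the same lsum/rsum score as the ordered case. */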

CvDTreeSplit* CvDTree::find_surrogate_split_ord( CvDTreeNode* node, int vi, uchar* _ext_buf )
{
    const float epsilon = FLT_EPSILON*2;
    const char* dir = (char*)data->direction->data.ptr;
    int n = node->sample_count, n1 = node->get_num_valid(vi);
    cv::AutoBuffer<uchar> inn_buf;
    if( !_ext_buf )
        inn_buf.allocate( n*(sizeof(int)*(data->have_priors ? 3 : 2) + sizeof(float)) );
    uchar* ext_buf = _ext_buf ? _ext_buf : inn_buf.data();
    float* values_buf = (float*)ext_buf;
    int* sorted_indices_buf = (int*)(values_buf + n);
    int* sample_indices_buf = sorted_indices_buf + n;
    const float* values = 0;
    const int* sorted_indices = 0;
    data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );

    // LL - number of samples that both the primary and the surrogate splits send to the left
    // LR - ... primary split sends to the left and the surrogate split sends to the right
    // RL - ... primary split sends to the right and the surrogate split sends to the left
    // RR - ... both send to the right
    int i, best_i = -1, best_inversed = 0;
    double best_val;

    if( !data->have_priors )
    {
        int LL = 0, RL = 0, LR, RR;
        int worst_val = cvFloor(node->maxlr), _best_val = worst_val;
        int sum = 0, sum_abs = 0;

        for( i = 0; i < n1; i++ )
        {
            int d = dir[sorted_indices[i]];
            sum += d; sum_abs += d & 1;
        }

        // sum_abs = R + L; sum = R - L
        RR = (sum_abs + sum) >> 1;
        LR = (sum_abs - sum) >> 1;

        // initially all the samples are sent to the right by the surrogate split,
        // LR of them are sent to the left by the primary split, and RR - to the right.
        // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
        for( i = 0; i < n1 - 1; i++ )
        {
            int d = dir[sorted_indices[i]];

            if( d < 0 )
            {
                LL++; LR--;
                if( LL + RR > _best_val && values[i] + epsilon < values[i+1] )
                {
                    _best_val = LL + RR; // keep the running best in the integer tracker
                    best_i = i; best_inversed = 0;
                }
            }
            else if( d > 0 )
            {
                RL++; RR--;
                if( RL + LR > _best_val && values[i] + epsilon < values[i+1] )
                {
                    _best_val = RL + LR;
                    best_i = i; best_inversed = 1;
                }
            }
        }
        best_val = _best_val;
    }
    else
    {
        double LL = 0, RL = 0, LR, RR;
        double worst_val = node->maxlr;
        double sum = 0, sum_abs = 0;
        const double* priors = data->priors_mult->data.db;
        int* responses_buf = sample_indices_buf + n;
        const int* responses = data->get_class_labels(node, responses_buf);
        best_val = worst_val;

        for( i = 0; i < n1; i++ )
        {
            int idx = sorted_indices[i];
            double w = priors[responses[idx]];
            int d = dir[idx];
            sum += d*w; sum_abs += (d & 1)*w;
        }

        // sum_abs = R + L; sum = R - L
        RR = (sum_abs + sum)*0.5;
        LR = (sum_abs - sum)*0.5;

        // initially all the samples are sent to the right by the surrogate split,
        // LR of them are sent to the left by the primary split, and RR - to the right.
        // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
        for( i = 0; i < n1 - 1; i++ )
        {
            int idx = sorted_indices[i];
            double w = priors[responses[idx]];
            int d = dir[idx];

            if( d < 0 )
            {
                LL += w; LR -= w;
                if( LL + RR > best_val && values[i] + epsilon < values[i+1] )
                {
                    best_val = LL + RR;
                    best_i = i; best_inversed = 0;
                }
            }
            else if( d > 0 )
            {
                RL += w; RR -= w;
                if( RL + LR > best_val && values[i] + epsilon < values[i+1] )
                {
                    best_val = RL + LR;
                    best_i = i; best_inversed = 1;
                }
            }
        }
    }

    return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi,
        (values[best_i] + values[best_i+1])*0.5f, best_i, best_inversed, (float)best_val ) : 0;
}
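
/* Editorial note: here a surrogate's quality is its (weighted) agreement with
   the primary split. Scanning the values in sorted order, LL + RR is the
   agreement of the rule "small values go left", RL + LR that of the inverted
   rule; whichever is larger at a valid threshold becomes the candidate, with
   best_inversed recording which orientation won. The final comparison against
   node->maxlr rejects surrogates that cannot outdo the default majority
   direction. */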

CvDTreeSplit* CvDTree::find_surrogate_split_cat( CvDTreeNode* node, int vi, uchar* _ext_buf )
{
    const char* dir = (char*)data->direction->data.ptr;
    int n = node->sample_count;
    int i, mi = data->cat_count->data.i[data->get_var_type(vi)], l_win = 0;

    int base_size = (2*(mi+1)+1)*sizeof(double) + (!data->have_priors ? 2*(mi+1)*sizeof(int) : 0);
    cv::AutoBuffer<uchar> inn_buf(base_size);
    if( !_ext_buf )
        inn_buf.allocate(base_size + n*(sizeof(int) + (data->have_priors ? sizeof(int) : 0)));
    uchar* base_buf = inn_buf.data();
    uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;

    int* labels_buf = (int*)ext_buf;
    const int* labels = data->get_cat_var_data(node, vi, labels_buf);

    // LL - number of samples that both the primary and the surrogate splits send to the left
    // LR - ... primary split sends to the left and the surrogate split sends to the right
    // RL - ... primary split sends to the right and the surrogate split sends to the left
    // RR - ... both send to the right
    CvDTreeSplit* split = data->new_split_cat( vi, 0 );
    double best_val = 0;
    double* lc = (double*)cv::alignPtr(base_buf, sizeof(double)) + 1;
    double* rc = lc + mi + 1;

    for( i = -1; i < mi; i++ )
        lc[i] = rc[i] = 0;

    // for each category calculate the weight of samples
    // sent to the left (lc) and to the right (rc) by the primary split
    if( !data->have_priors )
    {
        int* _lc = (int*)rc + 1;
        int* _rc = _lc + mi + 1;

        for( i = -1; i < mi; i++ )
            _lc[i] = _rc[i] = 0;

        for( i = 0; i < n; i++ )
        {
            int idx = ( (labels[i] == 65535) && (data->is_buf_16u) ) ? -1 : labels[i];
            int d = dir[i];
            int sum = _lc[idx] + d;
            int sum_abs = _rc[idx] + (d & 1);
            _lc[idx] = sum; _rc[idx] = sum_abs;
        }

        for( i = 0; i < mi; i++ )
        {
            int sum = _lc[i];
            int sum_abs = _rc[i];
            lc[i] = (sum_abs - sum) >> 1;
            rc[i] = (sum_abs + sum) >> 1;
        }
    }
    else
    {
        const double* priors = data->priors_mult->data.db;
        int* responses_buf = labels_buf + n;
        const int* responses = data->get_class_labels(node, responses_buf);

        for( i = 0; i < n; i++ )
        {
            int idx = ( (labels[i] == 65535) && (data->is_buf_16u) ) ? -1 : labels[i];
            double w = priors[responses[i]];
            int d = dir[i];
            double sum = lc[idx] + d*w;
            double sum_abs = rc[idx] + (d & 1)*w;
            lc[idx] = sum; rc[idx] = sum_abs;
        }

        for( i = 0; i < mi; i++ )
        {
            double sum = lc[i];
            double sum_abs = rc[i];
            lc[i] = (sum_abs - sum) * 0.5;
            rc[i] = (sum_abs + sum) * 0.5;
        }
    }
    // now form the split:
    // in each category, send all the samples in the same direction as the majority
    for( i = 0; i < mi; i++ )
    {
        double lval = lc[i], rval = rc[i];
        if( lval > rval )
        {
            split->subset[i >> 5] |= 1 << (i & 31);
            best_val += lval;
            l_win++;
        }
        else
            best_val += rval;
    }

    split->quality = (float)best_val;
    if( split->quality <= node->maxlr || l_win == 0 || l_win == mi )
        cvSetRemoveByPtr( data->split_heap, split ), split = 0;

    return split;
}

void CvDTree::calc_node_value( CvDTreeNode* node )
{
    int i, j, k, n = node->sample_count, cv_n = data->params.cv_folds;
    int m = data->get_num_classes();

    int base_size = data->is_classifier ? m*cv_n*sizeof(int) : 2*cv_n*sizeof(double) + cv_n*sizeof(int);
    int ext_size = n*(sizeof(int) + (data->is_classifier ? sizeof(int) : sizeof(int) + sizeof(float)));
    cv::AutoBuffer<uchar> inn_buf(base_size + ext_size);
    uchar* base_buf = inn_buf.data();
    uchar* ext_buf = base_buf + base_size;

    int* cv_labels_buf = (int*)ext_buf;
    const int* cv_labels = data->get_cv_labels(node, cv_labels_buf);

    if( data->is_classifier )
    {
        // in case of classification tree:
        //  * node value is the label of the class that has the largest weight in the node.
        //  * node risk is the weighted number of misclassified samples,
        //  * j-th cross-validation fold value and risk are calculated as above,
        //    but using the samples with cv_labels(*)!=j.
        //  * j-th cross-validation fold error is calculated as the weighted number of
        //    misclassified samples with cv_labels(*)==j.

        // compute the number of instances of each class
        int* cls_count = data->counts->data.i;
        int* responses_buf = cv_labels_buf + n;
        const int* responses = data->get_class_labels(node, responses_buf);
        int* cv_cls_count = (int*)base_buf;
        double max_val = -1, total_weight = 0;
        int max_k = -1;
        double* priors = data->priors_mult->data.db;

        for( k = 0; k < m; k++ )
            cls_count[k] = 0;

        if( cv_n == 0 )
        {
            for( i = 0; i < n; i++ )
                cls_count[responses[i]]++;
        }
        else
        {
            for( j = 0; j < cv_n; j++ )
                for( k = 0; k < m; k++ )
                    cv_cls_count[j*m + k] = 0;

            for( i = 0; i < n; i++ )
            {
                j = cv_labels[i]; k = responses[i];
                cv_cls_count[j*m + k]++;
            }

            for( j = 0; j < cv_n; j++ )
                for( k = 0; k < m; k++ )
                    cls_count[k] += cv_cls_count[j*m + k];
        }

        if( data->have_priors && node->parent == 0 )
        {
            // compute priors_mult from priors, take the sample ratio into account.
            double sum = 0;
            for( k = 0; k < m; k++ )
            {
                int n_k = cls_count[k];
                priors[k] = data->priors->data.db[k]*(n_k ? 1./n_k : 0.);
                sum += priors[k];
            }
            sum = 1./sum;
            for( k = 0; k < m; k++ )
                priors[k] *= sum;
        }

        for( k = 0; k < m; k++ )
        {
            double val = cls_count[k]*priors[k];
            total_weight += val;
            if( max_val < val )
            {
                max_val = val;
                max_k = k;
            }
        }

        node->class_idx = max_k;
        node->value = data->cat_map->data.i[
            data->cat_ofs->data.i[data->cat_var_count] + max_k];
        node->node_risk = total_weight - max_val;

        for( j = 0; j < cv_n; j++ )
        {
            double sum_k = 0, sum = 0, max_val_k = 0;
            max_val = -1; max_k = -1;

            for( k = 0; k < m; k++ )
            {
                double w = priors[k];
                double val_k = cv_cls_count[j*m + k]*w;
                double val = cls_count[k]*w - val_k;
                sum_k += val_k;
                sum += val;
                if( max_val < val )
                {
                    max_val = val;
                    max_val_k = val_k;
                    max_k = k;
                }
            }

            node->cv_Tn[j] = INT_MAX;
            node->cv_node_risk[j] = sum - max_val;
            node->cv_node_error[j] = sum_k - max_val_k;
        }
    }
    else
    {
        // in case of regression tree:
        //  * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
        //    n is the number of samples in the node.
        //  * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
        //  * j-th cross-validation fold value and risk are calculated as above,
        //    but using the samples with cv_labels(*)!=j.
        //  * j-th cross-validation fold error is calculated
        //    using samples with cv_labels(*)==j as the test subset:
        //    error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),
        //    where node_value_j is the node value calculated
        //    as described in the previous bullet, and summation is done
        //    over the samples with cv_labels(*)==j.
        double sum = 0, sum2 = 0;
        float* values_buf = (float*)(cv_labels_buf + n);
        int* sample_indices_buf = (int*)(values_buf + n);
        const float* values = data->get_ord_responses(node, values_buf, sample_indices_buf);
        double *cv_sum = 0, *cv_sum2 = 0;
        int* cv_count = 0;

        if( cv_n == 0 )
        {
            for( i = 0; i < n; i++ )
            {
                double t = values[i];
                sum += t;
                sum2 += t*t;
            }
        }
        else
        {
            cv_sum = (double*)base_buf;
            cv_sum2 = cv_sum + cv_n;
            cv_count = (int*)(cv_sum2 + cv_n);

            for( j = 0; j < cv_n; j++ )
            {
                cv_sum[j] = cv_sum2[j] = 0.;
                cv_count[j] = 0;
            }

            for( i = 0; i < n; i++ )
            {
                j = cv_labels[i];
                double t = values[i];
                double s = cv_sum[j] + t;
                double s2 = cv_sum2[j] + t*t;
                int nc = cv_count[j] + 1;
                cv_sum[j] = s;
                cv_sum2[j] = s2;
                cv_count[j] = nc;
            }

            for( j = 0; j < cv_n; j++ )
            {
                sum += cv_sum[j];
                sum2 += cv_sum2[j];
            }
        }

        node->node_risk = sum2 - (sum/n)*sum;
        node->value = sum/n;

        for( j = 0; j < cv_n; j++ )
        {
            double s = cv_sum[j], si = sum - s;
            double s2 = cv_sum2[j], s2i = sum2 - s2;
            int c = cv_count[j], ci = n - c;
            double r = si/MAX(ci, 1);
            node->cv_node_risk[j] = s2i - r*r*ci;
            node->cv_node_error[j] = s2 - 2*r*s + c*r*r;
        }
    }
}
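
/* Editorial check of the regression cross-validation algebra above: with
   s, s2, c the response sum, square sum and count of fold j, and
   si = sum - s, s2i = sum2 - s2, ci = n - c those of its complement, the
   fold-j model predicts r = si/ci (the mean over the training complement), so

       risk_j  = sum_{i not in fold j} (Y_i - r)^2 = s2i - r^2*ci
       error_j = sum_{i in fold j}     (Y_i - r)^2 = s2 - 2*r*s + c*r^2,

   which are exactly the two expressions assigned to cv_node_risk[j] and
   cv_node_error[j]. */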

void CvDTree::complete_node_dir( CvDTreeNode* node )
{
    int vi, i, n = node->sample_count, nl, nr, d0 = 0, d1 = -1;
    int nz = n - node->get_num_valid(node->split->var_idx);
    char* dir = (char*)data->direction->data.ptr;

    // try to complete direction using surrogate splits
    if( nz && data->params.use_surrogates )
    {
        cv::AutoBuffer<uchar> inn_buf(n*(2*sizeof(int) + sizeof(float)));
        CvDTreeSplit* split = node->split->next;

        for( ; split != 0 && nz; split = split->next )
        {
            int inversed_mask = split->inversed ? -1 : 0;
            vi = split->var_idx;

            if( data->get_var_type(vi) >= 0 ) // split on categorical var
            {
                int* labels_buf = (int*)inn_buf.data();
                const int* labels = data->get_cat_var_data(node, vi, labels_buf);
                const int* subset = split->subset;

                for( i = 0; i < n; i++ )
                {
                    int idx = labels[i];
                    if( !dir[i] && ( ((idx >= 0) && (!data->is_buf_16u)) || ((idx != 65535) && (data->is_buf_16u)) ))
                    {
                        int d = CV_DTREE_CAT_DIR(idx, subset);
                        dir[i] = (char)((d ^ inversed_mask) - inversed_mask);
                        if( --nz == 0 ) // stop once every missing direction has been filled in
                            break;
                    }
                }
            }
            else // split on ordered var
            {
                float* values_buf = (float*)inn_buf.data();
                int* sorted_indices_buf = (int*)(values_buf + n);
                int* sample_indices_buf = sorted_indices_buf + n;
                const float* values = 0;
                const int* sorted_indices = 0;
                data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );

                int split_point = split->ord.split_point;
                int n1 = node->get_num_valid(vi);

                assert( 0 <= split_point && split_point < n-1 );

                for( i = 0; i < n1; i++ )
                {
                    int idx = sorted_indices[i];
                    if( !dir[idx] )
                    {
                        int d = i <= split_point ? -1 : 1;
                        dir[idx] = (char)((d ^ inversed_mask) - inversed_mask);
                        if( --nz == 0 )
                            break;
                    }
                }
            }
        }
    }

    // find the default direction for the rest
    if( nz )
    {
        for( i = nr = 0; i < n; i++ )
            nr += dir[i] > 0;
        nl = n - nr - nz;
        d0 = nl > nr ? -1 : nr > nl;
    }

    // make sure that every sample is directed either to the left or to the right
    for( i = 0; i < n; i++ )
    {
        int d = dir[i];
        if( !d )
        {
            d = d0;
            if( !d )
                d = d1, d1 = -d1;
        }
        d = d > 0;
        dir[i] = (char)d; // remap (-1,1) to (0,1)
    }
}
  2578. void CvDTree::split_node_data( CvDTreeNode* node )
  2579. {
  2580. int vi, i, n = node->sample_count, nl, nr, scount = data->sample_count;
  2581. char* dir = (char*)data->direction->data.ptr;
  2582. CvDTreeNode *left = 0, *right = 0;
  2583. int* new_idx = data->split_buf->data.i;
  2584. int new_buf_idx = data->get_child_buf_idx( node );
  2585. int work_var_count = data->get_work_var_count();
  2586. CvMat* buf = data->buf;
  2587. size_t length_buf_row = data->get_length_subbuf();
  2588. cv::AutoBuffer<uchar> inn_buf(n*(3*sizeof(int) + sizeof(float)));
  2589. int* temp_buf = (int*)inn_buf.data();
  2590. complete_node_dir(node);
  2591. for( i = nl = nr = 0; i < n; i++ )
  2592. {
  2593. int d = dir[i];
  2594. // initialize new indices for splitting ordered variables
  2595. new_idx[i] = (nl & (d-1)) | (nr & -d); // d ? ri : li
  2596. nr += d;
  2597. nl += d^1;
  2598. }

    bool split_input_data;
    node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );
    node->right = right = data->new_node( node, nr, new_buf_idx, node->offset + nl );

    split_input_data = node->depth + 1 < data->params.max_depth &&
        (node->left->sample_count > data->params.min_sample_count ||
        node->right->sample_count > data->params.min_sample_count);

    // split ordered variables, keep both halves sorted.
    for( vi = 0; vi < data->var_count; vi++ )
    {
        int ci = data->get_var_type(vi);

        if( ci >= 0 || !split_input_data )
            continue;

        int n1 = node->get_num_valid(vi);
        float* src_val_buf = (float*)(uchar*)(temp_buf + n);
        int* src_sorted_idx_buf = (int*)(src_val_buf + n);
        int* src_sample_idx_buf = src_sorted_idx_buf + n;
        const float* src_val = 0;
        const int* src_sorted_idx = 0;
        data->get_ord_var_data(node, vi, src_val_buf, src_sorted_idx_buf, &src_val, &src_sorted_idx, src_sample_idx_buf);

        for(i = 0; i < n; i++)
            temp_buf[i] = src_sorted_idx[i];

        if (data->is_buf_16u)
        {
            unsigned short *ldst, *rdst, *ldst0, *rdst0;
            ldst0 = ldst = (unsigned short*)(buf->data.s + left->buf_idx*length_buf_row +
                vi*scount + left->offset);
            rdst0 = rdst = (unsigned short*)(ldst + nl);

            // split sorted
            for( i = 0; i < n1; i++ )
            {
                int idx = temp_buf[i];
                int d = dir[idx];
                idx = new_idx[idx];
                if (d)
                {
                    *rdst = (unsigned short)idx;
                    rdst++;
                }
                else
                {
                    *ldst = (unsigned short)idx;
                    ldst++;
                }
            }

            left->set_num_valid(vi, (int)(ldst - ldst0));
            right->set_num_valid(vi, (int)(rdst - rdst0));

            // split missing
            for( ; i < n; i++ )
            {
                int idx = temp_buf[i];
                int d = dir[idx];
                idx = new_idx[idx];
                if (d)
                {
                    *rdst = (unsigned short)idx;
                    rdst++;
                }
                else
                {
                    *ldst = (unsigned short)idx;
                    ldst++;
                }
            }
        }
        else
        {
            int *ldst0, *ldst, *rdst0, *rdst;
            ldst0 = ldst = buf->data.i + left->buf_idx*length_buf_row +
                vi*scount + left->offset;
            rdst0 = rdst = buf->data.i + right->buf_idx*length_buf_row +
                vi*scount + right->offset;

            // split sorted
            for( i = 0; i < n1; i++ )
            {
                int idx = temp_buf[i];
                int d = dir[idx];
                idx = new_idx[idx];
                if (d)
                {
                    *rdst = idx;
                    rdst++;
                }
                else
                {
                    *ldst = idx;
                    ldst++;
                }
            }

            left->set_num_valid(vi, (int)(ldst - ldst0));
            right->set_num_valid(vi, (int)(rdst - rdst0));

            // split missing
            for( ; i < n; i++ )
            {
                int idx = temp_buf[i];
                int d = dir[idx];
                idx = new_idx[idx];
                if (d)
                {
                    *rdst = idx;
                    rdst++;
                }
                else
                {
                    *ldst = idx;
                    ldst++;
                }
            }
        }
    }
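
    // Because temp_buf is walked in the parent's sorted order and elements
    // are only ever appended to each side, both children inherit correctly
    // sorted ordered-variable buffers and no re-sorting is needed.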

    // split categorical vars, responses and cv_labels using new_idx relocation table
    for( vi = 0; vi < work_var_count; vi++ )
    {
        int ci = data->get_var_type(vi);
        int n1 = node->get_num_valid(vi), nr1 = 0;

        if( ci < 0 || (vi < data->var_count && !split_input_data) )
            continue;

        int *src_lbls_buf = temp_buf + n;
        const int* src_lbls = data->get_cat_var_data(node, vi, src_lbls_buf);

        for(i = 0; i < n; i++)
            temp_buf[i] = src_lbls[i];

        if (data->is_buf_16u)
        {
            unsigned short *ldst = (unsigned short *)(buf->data.s + left->buf_idx*length_buf_row +
                vi*scount + left->offset);
            unsigned short *rdst = (unsigned short *)(buf->data.s + right->buf_idx*length_buf_row +
                vi*scount + right->offset);

            for( i = 0; i < n; i++ )
            {
                int d = dir[i];
                int idx = temp_buf[i];
                if (d)
                {
                    *rdst = (unsigned short)idx;
                    rdst++;
                    nr1 += (idx != 65535) & d;
                }
                else
                {
                    *ldst = (unsigned short)idx;
                    ldst++;
                }
            }

            if( vi < data->var_count )
            {
                left->set_num_valid(vi, n1 - nr1);
                right->set_num_valid(vi, nr1);
            }
        }
        else
        {
            int *ldst = buf->data.i + left->buf_idx*length_buf_row +
                vi*scount + left->offset;
            int *rdst = buf->data.i + right->buf_idx*length_buf_row +
                vi*scount + right->offset;

            for( i = 0; i < n; i++ )
            {
                int d = dir[i];
                int idx = temp_buf[i];
                if (d)
                {
                    *rdst = idx;
                    rdst++;
                    nr1 += (idx >= 0) & d;
                }
                else
                {
                    *ldst = idx;
                    ldst++;
                }
            }

            if( vi < data->var_count )
            {
                left->set_num_valid(vi, n1 - nr1);
                right->set_num_valid(vi, nr1);
            }
        }
    }
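
    // Unlike the ordered case, categorical values are stored in sample order,
    // so dir[i] is indexed directly instead of through the sorted-index
    // permutation. The missing-value sentinel also differs by buffer type:
    // 65535 in 16-bit buffers, a negative index in 32-bit ones.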

    // split sample indices
    int *sample_idx_src_buf = temp_buf + n;
    const int* sample_idx_src = data->get_sample_indices(node, sample_idx_src_buf);

    for(i = 0; i < n; i++)
        temp_buf[i] = sample_idx_src[i];

    int pos = data->get_work_var_count();

    if (data->is_buf_16u)
    {
        unsigned short* ldst = (unsigned short*)(buf->data.s + left->buf_idx*length_buf_row +
            pos*scount + left->offset);
        unsigned short* rdst = (unsigned short*)(buf->data.s + right->buf_idx*length_buf_row +
            pos*scount + right->offset);
        for (i = 0; i < n; i++)
        {
            int d = dir[i];
            unsigned short idx = (unsigned short)temp_buf[i];
            if (d)
            {
                *rdst = idx;
                rdst++;
            }
            else
            {
                *ldst = idx;
                ldst++;
            }
        }
    }
    else
    {
        int* ldst = buf->data.i + left->buf_idx*length_buf_row +
            pos*scount + left->offset;
        int* rdst = buf->data.i + right->buf_idx*length_buf_row +
            pos*scount + right->offset;
        for (i = 0; i < n; i++)
        {
            int d = dir[i];
            int idx = temp_buf[i];
            if (d)
            {
                *rdst = idx;
                rdst++;
            }
            else
            {
                *ldst = idx;
                ldst++;
            }
        }
    }

    // deallocate the parent node data that is not needed anymore
    data->free_node_data(node);
}

float CvDTree::calc_error( CvMLData* _data, int type, std::vector<float> *resp )
{
    float err = 0;
    const CvMat* values = _data->get_values();
    const CvMat* response = _data->get_responses();
    const CvMat* missing = _data->get_missing();
    const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();
    const CvMat* var_types = _data->get_var_types();
    int* sidx = sample_idx ? sample_idx->data.i : 0;
    int r_step = CV_IS_MAT_CONT(response->type) ?
        1 : response->step / CV_ELEM_SIZE(response->type);
    bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;
    int sample_count = sample_idx ? sample_idx->cols : 0;
    sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? values->rows : sample_count;
    float* pred_resp = 0;
    if( resp && (sample_count > 0) )
    {
        resp->resize( sample_count );
        pred_resp = &((*resp)[0]);
    }

    if ( is_classifier )
    {
        for( int i = 0; i < sample_count; i++ )
        {
            CvMat sample, miss;
            int si = sidx ? sidx[i] : i;
            cvGetRow( values, &sample, si );
            if( missing )
                cvGetRow( missing, &miss, si );
            float r = (float)predict( &sample, missing ? &miss : 0 )->value;
            if( pred_resp )
                pred_resp[i] = r;
            int d = fabs((double)r - response->data.fl[(size_t)si*r_step]) <= FLT_EPSILON ? 0 : 1;
            err += d;
        }
        err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
    }
    else
    {
        for( int i = 0; i < sample_count; i++ )
        {
            CvMat sample, miss;
            int si = sidx ? sidx[i] : i;
            cvGetRow( values, &sample, si );
            if( missing )
                cvGetRow( missing, &miss, si );
            float r = (float)predict( &sample, missing ? &miss : 0 )->value;
            if( pred_resp )
                pred_resp[i] = r;
            float d = r - response->data.fl[(size_t)si*r_step];
            err += d*d;
        }
        err = sample_count ? err / (float)sample_count : -FLT_MAX;
    }
    return err;
}
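
/* Illustrative use of calc_error() above -- a sketch with hypothetical names
   "mldata" (a loaded CvMLData instance) and "tree" (a trained CvDTree):

     std::vector<float> test_resp;
     float test_err = tree.calc_error( &mldata, CV_TEST_ERROR, &test_resp );

   For classifiers the result is the misclassification rate in percent over
   the selected sample subset; for regression trees it is the mean squared
   error. An empty subset yields -FLT_MAX. */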

void CvDTree::prune_cv()
{
    CvMat* ab = 0;
    CvMat* temp = 0;
    CvMat* err_jk = 0;

    // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.
    // 2. choose the best tree index (if needed, apply the 1SE rule).
    // 3. store the best index and cut the branches.

    CV_FUNCNAME( "CvDTree::prune_cv" );

    __BEGIN__;

    int ti, j, tree_count = 0, cv_n = data->params.cv_folds, n = root->sample_count;
    // currently, 1SE for regression is not implemented
    bool use_1se = data->params.use_1se_rule != 0 && data->is_classifier;
    double* err;
    double min_err = 0, min_err_se = 0;
    int min_idx = -1;

    CV_CALL( ab = cvCreateMat( 1, 256, CV_64F ));

    // build the main tree sequence, calculate alpha's
    for(;;tree_count++)
    {
        double min_alpha = update_tree_rnc(tree_count, -1);
        if( cut_tree(tree_count, -1, min_alpha) )
            break;

        if( ab->cols <= tree_count )
        {
            CV_CALL( temp = cvCreateMat( 1, ab->cols*3/2, CV_64F ));
            for( ti = 0; ti < ab->cols; ti++ )
                temp->data.db[ti] = ab->data.db[ti];
            cvReleaseMat( &ab );
            ab = temp;
            temp = 0;
        }

        ab->data.db[tree_count] = min_alpha;
    }

    ab->data.db[0] = 0.;

    if( tree_count > 0 )
    {
        for( ti = 1; ti < tree_count-1; ti++ )
            ab->data.db[ti] = sqrt(ab->data.db[ti]*ab->data.db[ti+1]);
        ab->data.db[tree_count-1] = DBL_MAX*0.5;

        CV_CALL( err_jk = cvCreateMat( cv_n, tree_count, CV_64F ));
        err = err_jk->data.db;

        for( j = 0; j < cv_n; j++ )
        {
            int tj = 0, tk = 0;
            for( ; tk < tree_count; tj++ )
            {
                double min_alpha = update_tree_rnc(tj, j);
                if( cut_tree(tj, j, min_alpha) )
                    min_alpha = DBL_MAX;

                for( ; tk < tree_count; tk++ )
                {
                    if( ab->data.db[tk] > min_alpha )
                        break;
                    err[j*tree_count + tk] = root->tree_error;
                }
            }
        }

        for( ti = 0; ti < tree_count; ti++ )
        {
            double sum_err = 0;
            for( j = 0; j < cv_n; j++ )
                sum_err += err[j*tree_count + ti];
            if( ti == 0 || sum_err < min_err )
            {
                min_err = sum_err;
                min_idx = ti;
                if( use_1se )
                    min_err_se = sqrt( sum_err*(n - sum_err) );
            }
            else if( sum_err < min_err + min_err_se )
                min_idx = ti;
        }
    }

    pruned_tree_idx = min_idx;
    free_prune_data(data->params.truncate_pruned_tree != 0);

    __END__;

    cvReleaseMat( &err_jk );
    cvReleaseMat( &ab );
    cvReleaseMat( &temp );
}
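
/* Cost-complexity bookkeeping for the pruning sequence. For every internal
   node t, update_tree_rnc() below accumulates over the subtree T_t the number
   of leaves (complexity), the total leaf risk (tree_risk) and the
   (cross-validated) error, and computes

       alpha(t) = (R(t) - R(T_t)) / (|T_t| - 1),

   i.e. the per-leaf increase in risk if T_t were collapsed into the single
   leaf t. The returned minimum alpha identifies the weakest link(s) to cut
   next. */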

double CvDTree::update_tree_rnc( int T, int fold )
{
    CvDTreeNode* node = root;
    double min_alpha = DBL_MAX;

    for(;;)
    {
        CvDTreeNode* parent;
        for(;;)
        {
            int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
            if( t <= T || !node->left )
            {
                node->complexity = 1;
                node->tree_risk = node->node_risk;
                node->tree_error = 0.;
                if( fold >= 0 )
                {
                    node->tree_risk = node->cv_node_risk[fold];
                    node->tree_error = node->cv_node_error[fold];
                }
                break;
            }
            node = node->left;
        }

        for( parent = node->parent; parent && parent->right == node;
            node = parent, parent = parent->parent )
        {
            parent->complexity += node->complexity;
            parent->tree_risk += node->tree_risk;
            parent->tree_error += node->tree_error;

            parent->alpha = ((fold >= 0 ? parent->cv_node_risk[fold] : parent->node_risk)
                - parent->tree_risk)/(parent->complexity - 1);
            min_alpha = MIN( min_alpha, parent->alpha );
        }

        if( !parent )
            break;

        parent->complexity = node->complexity;
        parent->tree_risk = node->tree_risk;
        parent->tree_error = node->tree_error;
        node = parent->right;
    }

    return min_alpha;
}
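
/* cut_tree() marks as pruned (at sequence step T) every subtree whose alpha
   does not exceed the current minimum; prune_cv() stops growing the sequence
   as soon as the cut reaches the root, which is signalled by the return
   value 1. */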

int CvDTree::cut_tree( int T, int fold, double min_alpha )
{
    CvDTreeNode* node = root;
    if( !node->left )
        return 1;

    for(;;)
    {
        CvDTreeNode* parent;
        for(;;)
        {
            int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
            if( t <= T || !node->left )
                break;
            if( node->alpha <= min_alpha + FLT_EPSILON )
            {
                if( fold >= 0 )
                    node->cv_Tn[fold] = T;
                else
                    node->Tn = T;
                if( node == root )
                    return 1;
                break;
            }
            node = node->left;
        }

        for( parent = node->parent; parent && parent->right == node;
            node = parent, parent = parent->parent )
            ;

        if( !parent )
            break;

        node = parent->right;
    }

    return 0;
}

void CvDTree::free_prune_data(bool _cut_tree)
{
    CvDTreeNode* node = root;

    for(;;)
    {
        CvDTreeNode* parent;
        for(;;)
        {
            // do not call cvSetRemoveByPtr( cv_heap, node->cv_Tn )
            // as we will clear the whole cross-validation heap at the end
            node->cv_Tn = 0;
            node->cv_node_error = node->cv_node_risk = 0;
            if( !node->left )
                break;
            node = node->left;
        }

        for( parent = node->parent; parent && parent->right == node;
            node = parent, parent = parent->parent )
        {
            if( _cut_tree && parent->Tn <= pruned_tree_idx )
            {
                data->free_node( parent->left );
                data->free_node( parent->right );
                parent->left = parent->right = 0;
            }
        }

        if( !parent )
            break;

        node = parent->right;
    }

    if( data->cv_heap )
        cvClearSet( data->cv_heap );
}

void CvDTree::free_tree()
{
    if( root && data && data->shared )
    {
        pruned_tree_idx = INT_MIN;
        free_prune_data(true);
        data->free_node(root);
        root = 0;
    }
}

CvDTreeNode* CvDTree::predict( const CvMat* _sample,
    const CvMat* _missing, bool preprocessed_input ) const
{
    cv::AutoBuffer<int> catbuf;

    int i, mstep = 0;
    const uchar* m = 0;
    CvDTreeNode* node = root;

    if( !node )
        CV_Error( CV_StsError, "The tree has not been trained yet" );

    if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||
        (_sample->cols != 1 && _sample->rows != 1) ||
        (_sample->cols + _sample->rows - 1 != data->var_all && !preprocessed_input) ||
        (_sample->cols + _sample->rows - 1 != data->var_count && preprocessed_input) )
        CV_Error( CV_StsBadArg,
        "the input sample must be a 1d floating-point vector with the same "
        "number of elements as the total number of variables used for training" );

    const float* sample = _sample->data.fl;
    int step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(sample[0]);

    if( data->cat_count && !preprocessed_input ) // cache for categorical variables
    {
        int n = data->cat_count->cols;
        catbuf.allocate(n);
        for( i = 0; i < n; i++ )
            catbuf[i] = -1;
    }

    if( _missing )
    {
        if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||
            !CV_ARE_SIZES_EQ(_missing, _sample) )
            CV_Error( CV_StsBadArg,
            "the missing data mask must be an 8-bit vector of the same size as the input sample" );
        m = _missing->data.ptr;
        mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step/sizeof(m[0]);
    }

    const int* vtype = data->var_type->data.i;
    const int* vidx = data->var_idx && !preprocessed_input ? data->var_idx->data.i : 0;
    const int* cmap = data->cat_map ? data->cat_map->data.i : 0;
    const int* cofs = data->cat_ofs ? data->cat_ofs->data.i : 0;

    while( node->Tn > pruned_tree_idx && node->left )
    {
        CvDTreeSplit* split = node->split;
        int dir = 0;
        for( ; !dir && split != 0; split = split->next )
        {
            int vi = split->var_idx;
            int ci = vtype[vi];
            i = vidx ? vidx[vi] : vi;
            float val = sample[(size_t)i*step];
            if( m && m[(size_t)i*mstep] )
                continue;
            if( ci < 0 ) // ordered
                dir = val <= split->ord.c ? -1 : 1;
            else // categorical
            {
                int c;
                if( preprocessed_input )
                    c = cvRound(val);
                else
                {
                    c = catbuf[ci];
                    if( c < 0 )
                    {
                        int a = c = cofs[ci];
                        int b = (ci+1 >= data->cat_ofs->cols) ? data->cat_map->cols : cofs[ci+1];

                        int ival = cvRound(val);
                        if( ival != val )
                            CV_Error( CV_StsBadArg,
                            "one of the input categorical variables is not an integer" );

                        int sh = 0;
                        while( a < b )
                        {
                            sh++;
                            c = (a + b) >> 1;
                            if( ival < cmap[c] )
                                b = c;
                            else if( ival > cmap[c] )
                                a = c+1;
                            else
                                break;
                        }

                        if( c < 0 || ival != cmap[c] )
                            continue;

                        catbuf[ci] = c -= cofs[ci];
                    }
                }
                c = ( (c == 65535) && data->is_buf_16u ) ? -1 : c;
                dir = CV_DTREE_CAT_DIR(c, split->subset);
            }

            if( split->inversed )
                dir = -dir;
        }

        if( !dir )
        {
            double diff = node->right->sample_count - node->left->sample_count;
            dir = diff < 0 ? -1 : 1;
        }
        node = dir < 0 ? node->left : node->right;
    }

    return node;
}
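
/* Illustrative use of predict() above -- a sketch with hypothetical feature
   values, assuming a trained tree "tree" with data->var_all == 4:

     float features[4] = { 5.1f, 3.5f, 1.4f, 0.2f };
     CvMat sample = cvMat( 1, 4, CV_32F, features );
     double prediction = tree.predict( &sample, 0, false )->value;

   The returned node's "value" field holds the predicted class label for
   classification trees or the fitted response for regression trees. */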

CvDTreeNode* CvDTree::predict( const Mat& _sample, const Mat& _missing, bool preprocessed_input ) const
{
    CvMat sample = cvMat(_sample), mmask = cvMat(_missing);
    return predict(&sample, mmask.data.ptr ? &mmask : 0, preprocessed_input);
}

const CvMat* CvDTree::get_var_importance()
{
    if( !var_importance )
    {
        CvDTreeNode* node = root;
        double* importance;
        if( !node )
            return 0;
        var_importance = cvCreateMat( 1, data->var_count, CV_64F );
        cvZero( var_importance );
        importance = var_importance->data.db;

        for(;;)
        {
            CvDTreeNode* parent;
            for( ;; node = node->left )
            {
                CvDTreeSplit* split = node->split;

                if( !node->left || node->Tn <= pruned_tree_idx )
                    break;

                for( ; split != 0; split = split->next )
                    importance[split->var_idx] += split->quality;
            }

            for( parent = node->parent; parent && parent->right == node;
                node = parent, parent = parent->parent )
                ;

            if( !parent )
                break;

            node = parent->right;
        }

        cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );
    }

    return var_importance;
}
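
/* Note: the importance of a variable is the sum of split qualities over every
   split (primary or surrogate) on that variable in the unpruned part of the
   tree, L1-normalized above so that the entries sum to 1. */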

void CvDTree::write_split( CvFileStorage* fs, CvDTreeSplit* split ) const
{
    int ci;

    cvStartWriteStruct( fs, 0, CV_NODE_MAP + CV_NODE_FLOW );
    cvWriteInt( fs, "var", split->var_idx );
    cvWriteReal( fs, "quality", split->quality );

    ci = data->get_var_type(split->var_idx);
    if( ci >= 0 ) // split on a categorical var
    {
        int i, n = data->cat_count->data.i[ci], to_right = 0, default_dir;
        for( i = 0; i < n; i++ )
            to_right += CV_DTREE_CAT_DIR(i,split->subset) > 0;

        // ad-hoc rule when to use inverse categorical split notation
        // to achieve more compact and clear representation
        default_dir = to_right <= 1 || to_right <= MIN(3, n/2) || to_right <= n/3 ? -1 : 1;

        cvStartWriteStruct( fs, default_dir*(split->inversed ? -1 : 1) > 0 ?
            "in" : "not_in", CV_NODE_SEQ+CV_NODE_FLOW );

        for( i = 0; i < n; i++ )
        {
            int dir = CV_DTREE_CAT_DIR(i,split->subset);
            if( dir*default_dir < 0 )
                cvWriteInt( fs, 0, i );
        }
        cvEndWriteStruct( fs );
    }
    else
        cvWriteReal( fs, !split->inversed ? "le" : "gt", split->ord.c );

    cvEndWriteStruct( fs );
}
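
/* For illustration, a split serialized by write_split() looks roughly like
   the following in YAML form (hypothetical values):

     { var:2, quality:3.75e-01, in:[ 0, 3 ] }   # categorical split
     { var:0, quality:8.19e-01, le:2.45e+00 }   # ordered split ("gt" if inversed)
*/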

void CvDTree::write_node( CvFileStorage* fs, CvDTreeNode* node ) const
{
    CvDTreeSplit* split;

    cvStartWriteStruct( fs, 0, CV_NODE_MAP );

    cvWriteInt( fs, "depth", node->depth );
    cvWriteInt( fs, "sample_count", node->sample_count );
    cvWriteReal( fs, "value", node->value );

    if( data->is_classifier )
        cvWriteInt( fs, "norm_class_idx", node->class_idx );

    cvWriteInt( fs, "Tn", node->Tn );
    cvWriteInt( fs, "complexity", node->complexity );
    cvWriteReal( fs, "alpha", node->alpha );
    cvWriteReal( fs, "node_risk", node->node_risk );
    cvWriteReal( fs, "tree_risk", node->tree_risk );
    cvWriteReal( fs, "tree_error", node->tree_error );

    if( node->left )
    {
        cvStartWriteStruct( fs, "splits", CV_NODE_SEQ );

        for( split = node->split; split != 0; split = split->next )
            write_split( fs, split );

        cvEndWriteStruct( fs );
    }

    cvEndWriteStruct( fs );
}

void CvDTree::write_tree_nodes( CvFileStorage* fs ) const
{
    //CV_FUNCNAME( "CvDTree::write_tree_nodes" );

    __BEGIN__;

    CvDTreeNode* node = root;

    // traverse the tree and save all the nodes in depth-first order
    for(;;)
    {
        CvDTreeNode* parent;
        for(;;)
        {
            write_node( fs, node );
            if( !node->left )
                break;
            node = node->left;
        }

        for( parent = node->parent; parent && parent->right == node;
            node = parent, parent = parent->parent )
            ;

        if( !parent )
            break;

        node = parent->right;
    }

    __END__;
}

void CvDTree::write( CvFileStorage* fs, const char* name ) const
{
    //CV_FUNCNAME( "CvDTree::write" );

    __BEGIN__;

    cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_TREE );

    //get_var_importance();
    data->write_params( fs );
    //if( var_importance )
    //cvWrite( fs, "var_importance", var_importance );
    write( fs );

    cvEndWriteStruct( fs );

    __END__;
}


void CvDTree::write( CvFileStorage* fs ) const
{
    //CV_FUNCNAME( "CvDTree::write" );

    __BEGIN__;

    cvWriteInt( fs, "best_tree_idx", pruned_tree_idx );

    cvStartWriteStruct( fs, "nodes", CV_NODE_SEQ );
    write_tree_nodes( fs );
    cvEndWriteStruct( fs );

    __END__;
}

CvDTreeSplit* CvDTree::read_split( CvFileStorage* fs, CvFileNode* fnode )
{
    CvDTreeSplit* split = 0;

    CV_FUNCNAME( "CvDTree::read_split" );

    __BEGIN__;

    int vi, ci;

    if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
        CV_ERROR( CV_StsParseError, "some of the splits are not stored properly" );

    vi = cvReadIntByName( fs, fnode, "var", -1 );
    if( (unsigned)vi >= (unsigned)data->var_count )
        CV_ERROR( CV_StsOutOfRange, "Split variable index is out of range" );

    ci = data->get_var_type(vi);
    if( ci >= 0 ) // split on categorical var
    {
        int i, n = data->cat_count->data.i[ci], inversed = 0, val;
        CvSeqReader reader;
        CvFileNode* inseq;
        split = data->new_split_cat( vi, 0 );
        inseq = cvGetFileNodeByName( fs, fnode, "in" );
        if( !inseq )
        {
            inseq = cvGetFileNodeByName( fs, fnode, "not_in" );
            inversed = 1;
        }
        if( !inseq ||
            (CV_NODE_TYPE(inseq->tag) != CV_NODE_SEQ && CV_NODE_TYPE(inseq->tag) != CV_NODE_INT))
            CV_ERROR( CV_StsParseError,
            "Either an 'in' or a 'not_in' tag must be inside the categorical split data" );
        if( CV_NODE_TYPE(inseq->tag) == CV_NODE_INT )
        {
            val = inseq->data.i;
            if( (unsigned)val >= (unsigned)n )
                CV_ERROR( CV_StsOutOfRange, "some of the in/not_in elements are out of range" );

            split->subset[val >> 5] |= 1 << (val & 31);
        }
        else
        {
            cvStartReadSeq( inseq->data.seq, &reader );

            for( i = 0; i < reader.seq->total; i++ )
            {
                CvFileNode* inode = (CvFileNode*)reader.ptr;
                val = inode->data.i;
                if( CV_NODE_TYPE(inode->tag) != CV_NODE_INT || (unsigned)val >= (unsigned)n )
                    CV_ERROR( CV_StsOutOfRange, "some of the in/not_in elements are out of range" );

                split->subset[val >> 5] |= 1 << (val & 31);
                CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
            }
        }

        // for categorical splits we do not use inversed splits,
        // instead we inverse the variable set in the split
        if( inversed )
            for( i = 0; i < (n + 31) >> 5; i++ )
                split->subset[i] ^= -1;
    }
    else
    {
        CvFileNode* cmp_node;
        split = data->new_split_ord( vi, 0, 0, 0, 0 );

        cmp_node = cvGetFileNodeByName( fs, fnode, "le" );
        if( !cmp_node )
        {
            cmp_node = cvGetFileNodeByName( fs, fnode, "gt" );
            split->inversed = 1;
        }

        split->ord.c = (float)cvReadReal( cmp_node );
    }

    split->quality = (float)cvReadRealByName( fs, fnode, "quality" );

    __END__;

    return split;
}

CvDTreeNode* CvDTree::read_node( CvFileStorage* fs, CvFileNode* fnode, CvDTreeNode* parent )
{
    CvDTreeNode* node = 0;

    CV_FUNCNAME( "CvDTree::read_node" );

    __BEGIN__;

    CvFileNode* splits;
    int i, depth;

    if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
        CV_ERROR( CV_StsParseError, "some of the tree elements are not stored properly" );

    CV_CALL( node = data->new_node( parent, 0, 0, 0 ));
    depth = cvReadIntByName( fs, fnode, "depth", -1 );
    if( depth != node->depth )
        CV_ERROR( CV_StsParseError, "incorrect node depth" );

    node->sample_count = cvReadIntByName( fs, fnode, "sample_count" );
    node->value = cvReadRealByName( fs, fnode, "value" );
    if( data->is_classifier )
        node->class_idx = cvReadIntByName( fs, fnode, "norm_class_idx" );

    node->Tn = cvReadIntByName( fs, fnode, "Tn" );
    node->complexity = cvReadIntByName( fs, fnode, "complexity" );
    node->alpha = cvReadRealByName( fs, fnode, "alpha" );
    node->node_risk = cvReadRealByName( fs, fnode, "node_risk" );
    node->tree_risk = cvReadRealByName( fs, fnode, "tree_risk" );
    node->tree_error = cvReadRealByName( fs, fnode, "tree_error" );

    splits = cvGetFileNodeByName( fs, fnode, "splits" );
    if( splits )
    {
        CvSeqReader reader;
        CvDTreeSplit* last_split = 0;

        if( CV_NODE_TYPE(splits->tag) != CV_NODE_SEQ )
            CV_ERROR( CV_StsParseError, "the splits tag must be stored as a sequence" );

        cvStartReadSeq( splits->data.seq, &reader );
        for( i = 0; i < reader.seq->total; i++ )
        {
            CvDTreeSplit* split;
            CV_CALL( split = read_split( fs, (CvFileNode*)reader.ptr ));
            if( !last_split )
                node->split = last_split = split;
            else
                last_split = last_split->next = split;

            CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
        }
    }

    __END__;

    return node;
}
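
/* The nodes were written in depth-first order, so read_tree_nodes() can
   rebuild the topology in a single pass: a node that carries splits becomes
   the parent of the nodes that follow, a leaf closes the current branch, and
   the walk climbs back up to the nearest ancestor still missing a right
   child. */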

void CvDTree::read_tree_nodes( CvFileStorage* fs, CvFileNode* fnode )
{
    CV_FUNCNAME( "CvDTree::read_tree_nodes" );

    __BEGIN__;

    CvSeqReader reader;
    CvDTreeNode _root;
    CvDTreeNode* parent = &_root;
    int i;
    parent->left = parent->right = parent->parent = 0;

    cvStartReadSeq( fnode->data.seq, &reader );

    for( i = 0; i < reader.seq->total; i++ )
    {
        CvDTreeNode* node;

        CV_CALL( node = read_node( fs, (CvFileNode*)reader.ptr, parent != &_root ? parent : 0 ));
        if( !parent->left )
            parent->left = node;
        else
            parent->right = node;
        if( node->split )
            parent = node;
        else
        {
            while( parent && parent->right )
                parent = parent->parent;
        }

        CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
    }

    root = _root.left;

    __END__;
}

void CvDTree::read( CvFileStorage* fs, CvFileNode* fnode )
{
    CvDTreeTrainData* _data = new CvDTreeTrainData();
    _data->read_params( fs, fnode );

    read( fs, fnode, _data );
    get_var_importance();
}


// a special entry point for reading weak decision trees from the tree ensembles
void CvDTree::read( CvFileStorage* fs, CvFileNode* node, CvDTreeTrainData* _data )
{
    CV_FUNCNAME( "CvDTree::read" );

    __BEGIN__;

    CvFileNode* tree_nodes;

    clear();
    data = _data;

    tree_nodes = cvGetFileNodeByName( fs, node, "nodes" );
    if( !tree_nodes || CV_NODE_TYPE(tree_nodes->tag) != CV_NODE_SEQ )
        CV_ERROR( CV_StsParseError, "nodes tag is missing" );

    pruned_tree_idx = cvReadIntByName( fs, node, "best_tree_idx", -1 );
    read_tree_nodes( fs, tree_nodes );

    __END__;
}

Mat CvDTree::getVarImportance()
{
    return cvarrToMat(get_var_importance());
}

/* End of file. */