/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                        Intel License Agreement
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of Intel Corporation may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "old_ml_precomp.hpp"

static inline double
log_ratio( double val )
{
    const double eps = 1e-5;

    val = MAX( val, eps );
    val = MIN( val, 1. - eps );
    return log( val/(1. - val) );
}
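
// For reference: log_ratio(0.5) == 0, and because the input is clamped to
// [1e-5, 1 - 1e-5], log_ratio(0.) returns log(1e-5/(1 - 1e-5)) ~= -11.51
// instead of -infinity. This keeps the leaf values and the Discrete AdaBoost
// coefficient C finite even for pure nodes or zero training error.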
CvBoostParams::CvBoostParams()
{
    boost_type = CvBoost::REAL;
    weak_count = 100;
    weight_trim_rate = 0.95;
    cv_folds = 0;
    max_depth = 1;
}

CvBoostParams::CvBoostParams( int _boost_type, int _weak_count,
                              double _weight_trim_rate, int _max_depth,
                              bool _use_surrogates, const float* _priors )
{
    boost_type = _boost_type;
    weak_count = _weak_count;
    weight_trim_rate = _weight_trim_rate;
    split_criteria = CvBoost::DEFAULT;
    cv_folds = 0;
    max_depth = _max_depth;
    use_surrogates = _use_surrogates;
    priors = _priors;
}
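
// A minimal construction sketch (hypothetical values): Real AdaBoost with
// 100 depth-1 trees ("stumps") and 5% weight trimming. The argument order
// matches the constructor above: boost_type, weak_count, weight_trim_rate,
// max_depth, use_surrogates, priors.
//
//     CvBoostParams params( CvBoost::REAL, 100, 0.95, 1, false, 0 );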
///////////////////////////////// CvBoostTree ///////////////////////////////////

CvBoostTree::CvBoostTree()
{
    ensemble = 0;
}

CvBoostTree::~CvBoostTree()
{
    clear();
}

void
CvBoostTree::clear()
{
    CvDTree::clear();
    ensemble = 0;
}

bool
CvBoostTree::train( CvDTreeTrainData* _train_data,
                    const CvMat* _subsample_idx, CvBoost* _ensemble )
{
    clear();
    ensemble = _ensemble;
    data = _train_data;
    data->shared = true;
    return do_train( _subsample_idx );
}

bool
CvBoostTree::train( const CvMat*, int, const CvMat*, const CvMat*,
                    const CvMat*, const CvMat*, const CvMat*, CvDTreeParams )
{
    assert(0);
    return false;
}

bool
CvBoostTree::train( CvDTreeTrainData*, const CvMat* )
{
    assert(0);
    return false;
}

void
CvBoostTree::scale( double _scale )
{
    CvDTreeNode* node = root;

    // traverse the tree and scale all the node values
    for(;;)
    {
        CvDTreeNode* parent;
        for(;;)
        {
            node->value *= _scale;
            if( !node->left )
                break;
            node = node->left;
        }

        for( parent = node->parent; parent && parent->right == node;
             node = parent, parent = parent->parent )
            ;

        if( !parent )
            break;

        node = parent->right;
    }
}
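
// The loop above is an iterative depth-first traversal: descend along left
// children to a leaf, then climb while the current node is a right child,
// and continue with the right sibling. update_weights() relies on scale()
// in the DISCRETE case, calling tree->scale(C) to fold the AdaBoost
// coefficient C into the stored node values, so prediction can simply sum
// the per-tree values.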
void
CvBoostTree::try_split_node( CvDTreeNode* node )
{
    CvDTree::try_split_node( node );

    if( !node->left )
    {
        // if the node has not been split,
        // store the responses for the corresponding training samples
        double* weak_eval = ensemble->get_weak_response()->data.db;
        cv::AutoBuffer<int> inn_buf(node->sample_count);
        const int* labels = data->get_cv_labels(node, inn_buf.data());
        int i, count = node->sample_count;
        double value = node->value;

        for( i = 0; i < count; i++ )
            weak_eval[labels[i]] = value;
    }
}
double
CvBoostTree::calc_node_dir( CvDTreeNode* node )
{
    char* dir = (char*)data->direction->data.ptr;
    const double* weights = ensemble->get_subtree_weights()->data.db;
    int i, n = node->sample_count, vi = node->split->var_idx;
    double L, R;

    assert( !node->split->inversed );

    if( data->get_var_type(vi) >= 0 ) // split on categorical var
    {
        cv::AutoBuffer<int> inn_buf(n);
        const int* cat_labels = data->get_cat_var_data(node, vi, inn_buf.data());
        const int* subset = node->split->subset;
        double sum = 0, sum_abs = 0;

        for( i = 0; i < n; i++ )
        {
            int idx = ((cat_labels[i] == 65535) && data->is_buf_16u) ? -1 : cat_labels[i];
            double w = weights[i];
            int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
            sum += d*w; sum_abs += (d & 1)*w;
            dir[i] = (char)d;
        }

        R = (sum_abs + sum) * 0.5;
        L = (sum_abs - sum) * 0.5;
    }
    else // split on ordered var
    {
        cv::AutoBuffer<uchar> inn_buf(2*n*sizeof(int)+n*sizeof(float));
        float* values_buf = (float*)inn_buf.data();
        int* sorted_indices_buf = (int*)(values_buf + n);
        int* sample_indices_buf = sorted_indices_buf + n;
        const float* values = 0;
        const int* sorted_indices = 0;
        data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );
        int split_point = node->split->ord.split_point;
        int n1 = node->get_num_valid(vi);

        assert( 0 <= split_point && split_point < n1-1 );
        L = R = 0;

        for( i = 0; i <= split_point; i++ )
        {
            int idx = sorted_indices[i];
            double w = weights[idx];
            dir[idx] = (char)-1;
            L += w;
        }

        for( ; i < n1; i++ )
        {
            int idx = sorted_indices[i];
            double w = weights[idx];
            dir[idx] = (char)1;
            R += w;
        }

        for( ; i < n; i++ )
            dir[sorted_indices[i]] = (char)0;
    }

    node->maxlr = MAX( L, R );
    return node->split->quality/(L + R);
}
CvDTreeSplit*
CvBoostTree::find_split_ord_class( CvDTreeNode* node, int vi, float init_quality,
                                   CvDTreeSplit* _split, uchar* _ext_buf )
{
    const float epsilon = FLT_EPSILON*2;
    const double* weights = ensemble->get_subtree_weights()->data.db;
    int n = node->sample_count;
    int n1 = node->get_num_valid(vi);

    cv::AutoBuffer<uchar> inn_buf;
    if( !_ext_buf )
        inn_buf.allocate(n*(3*sizeof(int)+sizeof(float)));
    uchar* ext_buf = _ext_buf ? _ext_buf : inn_buf.data();
    float* values_buf = (float*)ext_buf;
    int* sorted_indices_buf = (int*)(values_buf + n);
    int* sample_indices_buf = sorted_indices_buf + n;
    const float* values = 0;
    const int* sorted_indices = 0;
    data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );
    int* responses_buf = sorted_indices_buf + n;
    const int* responses = data->get_class_labels( node, responses_buf );
    const double* rcw0 = weights + n;
    double lcw[2] = {0,0}, rcw[2];
    int i, best_i = -1;
    double best_val = init_quality;
    int boost_type = ensemble->get_params().boost_type;
    int split_criteria = ensemble->get_params().split_criteria;

    rcw[0] = rcw0[0]; rcw[1] = rcw0[1];
    for( i = n1; i < n; i++ )
    {
        int idx = sorted_indices[i];
        double w = weights[idx];
        rcw[responses[idx]] -= w;
    }

    if( split_criteria != CvBoost::GINI && split_criteria != CvBoost::MISCLASS )
        split_criteria = boost_type == CvBoost::DISCRETE ? CvBoost::MISCLASS : CvBoost::GINI;

    if( split_criteria == CvBoost::GINI )
    {
        double L = 0, R = rcw[0] + rcw[1];
        double lsum2 = 0, rsum2 = rcw[0]*rcw[0] + rcw[1]*rcw[1];

        for( i = 0; i < n1 - 1; i++ )
        {
            int idx = sorted_indices[i];
            double w = weights[idx], w2 = w*w;
            double lv, rv;
            idx = responses[idx];
            L += w; R -= w;
            lv = lcw[idx]; rv = rcw[idx];
            lsum2 += 2*lv*w + w2;
            rsum2 -= 2*rv*w - w2;
            lcw[idx] = lv + w; rcw[idx] = rv - w;

            if( values[i] + epsilon < values[i+1] )
            {
                double val = (lsum2*R + rsum2*L)/(L*R);
                if( best_val < val )
                {
                    best_val = val;
                    best_i = i;
                }
            }
        }
    }
    else
    {
        for( i = 0; i < n1 - 1; i++ )
        {
            int idx = sorted_indices[i];
            double w = weights[idx];
            idx = responses[idx];
            lcw[idx] += w;
            rcw[idx] -= w;

            if( values[i] + epsilon < values[i+1] )
            {
                double val = lcw[0] + rcw[1], val2 = lcw[1] + rcw[0];
                val = MAX(val, val2);
                if( best_val < val )
                {
                    best_val = val;
                    best_i = i;
                }
            }
        }
    }

    CvDTreeSplit* split = 0;
    if( best_i >= 0 )
    {
        split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
        split->var_idx = vi;
        split->ord.c = (values[best_i] + values[best_i+1])*0.5f;
        split->ord.split_point = best_i;
        split->inversed = 0;
        split->quality = (float)best_val;
    }
    return split;
}
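
/* Note on the GINI quality used above. With per-class weight sums
   lcw[0], lcw[1] on the left (L = lcw[0]+lcw[1], lsum2 = lcw[0]^2+lcw[1]^2)
   and likewise on the right, the weighted Gini impurity of the split is

       L*(1 - lsum2/L^2) + R*(1 - rsum2/R^2) = (L + R) - (lsum2/L + rsum2/R),

   so minimizing the impurity is equivalent to maximizing
   lsum2/L + rsum2/R = (lsum2*R + rsum2*L)/(L*R), which is exactly the
   expression evaluated in the loop. */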
template<typename T>
class LessThanPtr
{
public:
    bool operator()(T* a, T* b) const { return *a < *b; }
};

CvDTreeSplit*
CvBoostTree::find_split_cat_class( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
{
    int ci = data->get_var_type(vi);
    int n = node->sample_count;
    int mi = data->cat_count->data.i[ci];

    int base_size = (2*mi+3)*sizeof(double) + mi*sizeof(double*);
    cv::AutoBuffer<uchar> inn_buf(base_size);
    if( !_ext_buf )
        inn_buf.allocate( base_size + 2*n*sizeof(int) );
    uchar* base_buf = inn_buf.data();
    uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;

    int* cat_labels_buf = (int*)ext_buf;
    const int* cat_labels = data->get_cat_var_data(node, vi, cat_labels_buf);
    int* responses_buf = cat_labels_buf + n;
    const int* responses = data->get_class_labels(node, responses_buf);
    double lcw[2]={0,0}, rcw[2]={0,0};

    double* cjk = (double*)cv::alignPtr(base_buf,sizeof(double))+2;
    const double* weights = ensemble->get_subtree_weights()->data.db;
    double** dbl_ptr = (double**)(cjk + 2*mi);
    int i, j, k, idx;
    double L = 0, R;
    double best_val = init_quality;
    int best_subset = -1, subset_i;
    int boost_type = ensemble->get_params().boost_type;
    int split_criteria = ensemble->get_params().split_criteria;

    // init array of counters:
    // c_{jk} - number of samples that have vi-th input variable = j and response = k.
    for( j = -1; j < mi; j++ )
        cjk[j*2] = cjk[j*2+1] = 0;
    for( i = 0; i < n; i++ )
    {
        double w = weights[i];
        j = ((cat_labels[i] == 65535) && data->is_buf_16u) ? -1 : cat_labels[i];
        k = responses[i];
        cjk[j*2 + k] += w;
    }

    for( j = 0; j < mi; j++ )
    {
        rcw[0] += cjk[j*2];
        rcw[1] += cjk[j*2+1];
        dbl_ptr[j] = cjk + j*2 + 1;
    }

    R = rcw[0] + rcw[1];

    if( split_criteria != CvBoost::GINI && split_criteria != CvBoost::MISCLASS )
        split_criteria = boost_type == CvBoost::DISCRETE ? CvBoost::MISCLASS : CvBoost::GINI;

    // sort rows of c_jk by increasing c_j,1
    // (i.e. by the weight of samples in j-th category that belong to class 1)
    std::sort(dbl_ptr, dbl_ptr + mi, LessThanPtr<double>());

    for( subset_i = 0; subset_i < mi-1; subset_i++ )
    {
        idx = (int)(dbl_ptr[subset_i] - cjk)/2;
        const double* crow = cjk + idx*2;
        double w0 = crow[0], w1 = crow[1];
        double weight = w0 + w1;

        if( weight < FLT_EPSILON )
            continue;

        lcw[0] += w0; rcw[0] -= w0;
        lcw[1] += w1; rcw[1] -= w1;

        if( split_criteria == CvBoost::GINI )
        {
            double lsum2 = lcw[0]*lcw[0] + lcw[1]*lcw[1];
            double rsum2 = rcw[0]*rcw[0] + rcw[1]*rcw[1];

            L += weight;
            R -= weight;

            if( L > FLT_EPSILON && R > FLT_EPSILON )
            {
                double val = (lsum2*R + rsum2*L)/(L*R);
                if( best_val < val )
                {
                    best_val = val;
                    best_subset = subset_i;
                }
            }
        }
        else
        {
            double val = lcw[0] + rcw[1];
            double val2 = lcw[1] + rcw[0];

            val = MAX(val, val2);
            if( best_val < val )
            {
                best_val = val;
                best_subset = subset_i;
            }
        }
    }

    CvDTreeSplit* split = 0;
    if( best_subset >= 0 )
    {
        split = _split ? _split : data->new_split_cat( 0, -1.0f);
        split->var_idx = vi;
        split->quality = (float)best_val;
        memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
        for( i = 0; i <= best_subset; i++ )
        {
            idx = (int)(dbl_ptr[i] - cjk) >> 1;
            split->subset[idx >> 5] |= 1 << (idx & 31);
        }
    }
    return split;
}
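
// Subset encoding used above: category idx belongs to the left branch when
// bit (idx & 31) of the 32-bit word subset[idx >> 5] is set (this is the
// convention CV_DTREE_CAT_DIR decodes); e.g. category 37 is tracked by
// bit 5 of subset[1].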
CvDTreeSplit*
CvBoostTree::find_split_ord_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
{
    const float epsilon = FLT_EPSILON*2;
    const double* weights = ensemble->get_subtree_weights()->data.db;
    int n = node->sample_count;
    int n1 = node->get_num_valid(vi);

    cv::AutoBuffer<uchar> inn_buf;
    if( !_ext_buf )
        inn_buf.allocate(2*n*(sizeof(int)+sizeof(float)));
    uchar* ext_buf = _ext_buf ? _ext_buf : inn_buf.data();

    float* values_buf = (float*)ext_buf;
    int* indices_buf = (int*)(values_buf + n);
    int* sample_indices_buf = indices_buf + n;
    const float* values = 0;
    const int* indices = 0;
    data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices, sample_indices_buf );
    float* responses_buf = (float*)(indices_buf + n);
    const float* responses = data->get_ord_responses( node, responses_buf, sample_indices_buf );

    int i, best_i = -1;
    double L = 0, R = weights[n];
    double best_val = init_quality, lsum = 0, rsum = node->value*R;

    // compensate for missing values
    for( i = n1; i < n; i++ )
    {
        int idx = indices[i];
        double w = weights[idx];
        rsum -= responses[idx]*w;
        R -= w;
    }

    // find the optimal split
    for( i = 0; i < n1 - 1; i++ )
    {
        int idx = indices[i];
        double w = weights[idx];
        double t = responses[idx]*w;
        L += w; R -= w;
        lsum += t; rsum -= t;

        if( values[i] + epsilon < values[i+1] )
        {
            double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
            if( best_val < val )
            {
                best_val = val;
                best_i = i;
            }
        }
    }

    CvDTreeSplit* split = 0;
    if( best_i >= 0 )
    {
        split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
        split->var_idx = vi;
        split->ord.c = (values[best_i] + values[best_i+1])*0.5f;
        split->ord.split_point = best_i;
        split->inversed = 0;
        split->quality = (float)best_val;
    }
    return split;
}
CvDTreeSplit*
CvBoostTree::find_split_cat_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
{
    const double* weights = ensemble->get_subtree_weights()->data.db;
    int ci = data->get_var_type(vi);
    int n = node->sample_count;
    int mi = data->cat_count->data.i[ci];
    int base_size = (2*mi+3)*sizeof(double) + mi*sizeof(double*);
    cv::AutoBuffer<uchar> inn_buf(base_size);
    if( !_ext_buf )
        inn_buf.allocate(base_size + n*(2*sizeof(int) + sizeof(float)));
    uchar* base_buf = inn_buf.data();
    uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;

    int* cat_labels_buf = (int*)ext_buf;
    const int* cat_labels = data->get_cat_var_data(node, vi, cat_labels_buf);
    float* responses_buf = (float*)(cat_labels_buf + n);
    int* sample_indices_buf = (int*)(responses_buf + n);
    const float* responses = data->get_ord_responses(node, responses_buf, sample_indices_buf);

    double* sum = (double*)cv::alignPtr(base_buf,sizeof(double)) + 1;
    double* counts = sum + mi + 1;
    double** sum_ptr = (double**)(counts + mi);
    double L = 0, R = 0, best_val = init_quality, lsum = 0, rsum = 0;
    int i, best_subset = -1, subset_i;

    for( i = -1; i < mi; i++ )
        sum[i] = counts[i] = 0;

    // calculate the sum of responses and the weight of each category of the input var
    for( i = 0; i < n; i++ )
    {
        int idx = ((cat_labels[i] == 65535) && data->is_buf_16u) ? -1 : cat_labels[i];
        double w = weights[i];
        double s = sum[idx] + responses[i]*w;
        double nc = counts[idx] + w;
        sum[idx] = s;
        counts[idx] = nc;
    }

    // calculate the average response in each category
    for( i = 0; i < mi; i++ )
    {
        R += counts[i];
        rsum += sum[i];
        sum[i] = fabs(counts[i]) > DBL_EPSILON ? sum[i]/counts[i] : 0;
        sum_ptr[i] = sum + i;
    }

    std::sort(sum_ptr, sum_ptr + mi, LessThanPtr<double>());

    // revert back to unnormalized sums
    // (there should be very little loss in accuracy)
    for( i = 0; i < mi; i++ )
        sum[i] *= counts[i];

    for( subset_i = 0; subset_i < mi-1; subset_i++ )
    {
        int idx = (int)(sum_ptr[subset_i] - sum);
        double ni = counts[idx];

        if( ni > FLT_EPSILON )
        {
            double s = sum[idx];
            lsum += s; L += ni;
            rsum -= s; R -= ni;

            if( L > FLT_EPSILON && R > FLT_EPSILON )
            {
                double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
                if( best_val < val )
                {
                    best_val = val;
                    best_subset = subset_i;
                }
            }
        }
    }

    CvDTreeSplit* split = 0;
    if( best_subset >= 0 )
    {
        split = _split ? _split : data->new_split_cat( 0, -1.0f);
        split->var_idx = vi;
        split->quality = (float)best_val;
        memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
        for( i = 0; i <= best_subset; i++ )
        {
            int idx = (int)(sum_ptr[i] - sum);
            split->subset[idx >> 5] |= 1 << (idx & 31);
        }
    }
    return split;
}
CvDTreeSplit*
CvBoostTree::find_surrogate_split_ord( CvDTreeNode* node, int vi, uchar* _ext_buf )
{
    const float epsilon = FLT_EPSILON*2;
    int n = node->sample_count;
    cv::AutoBuffer<uchar> inn_buf;
    if( !_ext_buf )
        inn_buf.allocate(n*(2*sizeof(int)+sizeof(float)));
    uchar* ext_buf = _ext_buf ? _ext_buf : inn_buf.data();
    float* values_buf = (float*)ext_buf;
    int* indices_buf = (int*)(values_buf + n);
    int* sample_indices_buf = indices_buf + n;
    const float* values = 0;
    const int* indices = 0;
    data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices, sample_indices_buf );

    const double* weights = ensemble->get_subtree_weights()->data.db;
    const char* dir = (char*)data->direction->data.ptr;
    int n1 = node->get_num_valid(vi);

    // LL - number of samples that both the primary and the surrogate splits send to the left
    // LR - ... primary split sends to the left and the surrogate split sends to the right
    // RL - ... primary split sends to the right and the surrogate split sends to the left
    // RR - ... both send to the right
    int i, best_i = -1, best_inversed = 0;
    double best_val;
    double LL = 0, RL = 0, LR, RR;
    double worst_val = node->maxlr;
    double sum = 0, sum_abs = 0;
    best_val = worst_val;

    for( i = 0; i < n1; i++ )
    {
        int idx = indices[i];
        double w = weights[idx];
        int d = dir[idx];
        sum += d*w; sum_abs += (d & 1)*w;
    }

    // sum_abs = R + L; sum = R - L
    RR = (sum_abs + sum)*0.5;
    LR = (sum_abs - sum)*0.5;

    // initially all the samples are sent to the right by the surrogate split,
    // LR of them are sent to the left by the primary split, and RR - to the right.
    // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
    for( i = 0; i < n1 - 1; i++ )
    {
        int idx = indices[i];
        double w = weights[idx];
        int d = dir[idx];

        if( d < 0 )
        {
            LL += w; LR -= w;
            if( LL + RR > best_val && values[i] + epsilon < values[i+1] )
            {
                best_val = LL + RR;
                best_i = i; best_inversed = 0;
            }
        }
        else if( d > 0 )
        {
            RL += w; RR -= w;
            if( RL + LR > best_val && values[i] + epsilon < values[i+1] )
            {
                best_val = RL + LR;
                best_i = i; best_inversed = 1;
            }
        }
    }

    return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi,
        (values[best_i] + values[best_i+1])*0.5f, best_i,
        best_inversed, (float)best_val ) : 0;
}
CvDTreeSplit*
CvBoostTree::find_surrogate_split_cat( CvDTreeNode* node, int vi, uchar* _ext_buf )
{
    const char* dir = (char*)data->direction->data.ptr;
    const double* weights = ensemble->get_subtree_weights()->data.db;
    int n = node->sample_count;
    int i, mi = data->cat_count->data.i[data->get_var_type(vi)];

    int base_size = (2*mi+3)*sizeof(double);
    cv::AutoBuffer<uchar> inn_buf(base_size);
    if( !_ext_buf )
        inn_buf.allocate(base_size + n*sizeof(int));
    uchar* ext_buf = _ext_buf ? _ext_buf : inn_buf.data();
    int* cat_labels_buf = (int*)ext_buf;
    const int* cat_labels = data->get_cat_var_data(node, vi, cat_labels_buf);

    // LL - number of samples that both the primary and the surrogate splits send to the left
    // LR - ... primary split sends to the left and the surrogate split sends to the right
    // RL - ... primary split sends to the right and the surrogate split sends to the left
    // RR - ... both send to the right
    CvDTreeSplit* split = data->new_split_cat( vi, 0 );
    double best_val = 0;
    double* lc = (double*)cv::alignPtr(cat_labels_buf + n, sizeof(double)) + 1;
    double* rc = lc + mi + 1;

    for( i = -1; i < mi; i++ )
        lc[i] = rc[i] = 0;

    // 1. for each category calculate the weight of samples
    //    sent to the left (lc) and to the right (rc) by the primary split
    for( i = 0; i < n; i++ )
    {
        int idx = ((cat_labels[i] == 65535) && data->is_buf_16u) ? -1 : cat_labels[i];
        double w = weights[i];
        int d = dir[i];
        double sum = lc[idx] + d*w;
        double sum_abs = rc[idx] + (d & 1)*w;
        lc[idx] = sum; rc[idx] = sum_abs;
    }

    for( i = 0; i < mi; i++ )
    {
        double sum = lc[i];
        double sum_abs = rc[i];
        lc[i] = (sum_abs - sum) * 0.5;
        rc[i] = (sum_abs + sum) * 0.5;
    }

    // 2. now form the split.
    //    in each category, send all the samples in the same direction as the majority
    for( i = 0; i < mi; i++ )
    {
        double lval = lc[i], rval = rc[i];
        if( lval > rval )
        {
            split->subset[i >> 5] |= 1 << (i & 31);
            best_val += lval;
        }
        else
            best_val += rval;
    }

    split->quality = (float)best_val;
    if( split->quality <= node->maxlr )
    {
        cvSetRemoveByPtr( data->split_heap, split );
        split = 0;
    }

    return split;
}
void
CvBoostTree::calc_node_value( CvDTreeNode* node )
{
    int i, n = node->sample_count;
    const double* weights = ensemble->get_weights()->data.db;
    cv::AutoBuffer<uchar> inn_buf(n*(sizeof(int) + ( data->is_classifier ? sizeof(int) : sizeof(int) + sizeof(float))));
    int* labels_buf = (int*)inn_buf.data();
    const int* labels = data->get_cv_labels(node, labels_buf);
    double* subtree_weights = ensemble->get_subtree_weights()->data.db;
    double rcw[2] = {0,0};
    int boost_type = ensemble->get_params().boost_type;

    if( data->is_classifier )
    {
        int* _responses_buf = labels_buf + n;
        const int* _responses = data->get_class_labels(node, _responses_buf);
        int m = data->get_num_classes();
        int* cls_count = data->counts->data.i;
        for( int k = 0; k < m; k++ )
            cls_count[k] = 0;

        for( i = 0; i < n; i++ )
        {
            int idx = labels[i];
            double w = weights[idx];
            int r = _responses[i];
            rcw[r] += w;
            cls_count[r]++;
            subtree_weights[i] = w;
        }

        node->class_idx = rcw[1] > rcw[0];

        if( boost_type == CvBoost::DISCRETE )
        {
            // ignore cat_map for responses, and use {-1,1},
            // as the whole ensemble response is computed as sign(sum_i(weak_response_i))
            node->value = node->class_idx*2 - 1;
        }
        else
        {
            double p = rcw[1]/(rcw[0] + rcw[1]);
            assert( boost_type == CvBoost::REAL );

            // store the log-ratio of the probability
            node->value = 0.5*log_ratio(p);
        }
    }
    else
    {
        // in case of regression tree:
        //  * node value is 1/n*sum_i(Y_i), where Y_i is the i-th response,
        //    n is the number of samples in the node.
        //  * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
        double sum = 0, sum2 = 0, iw;
        float* values_buf = (float*)(labels_buf + n);
        int* sample_indices_buf = (int*)(values_buf + n);
        const float* values = data->get_ord_responses(node, values_buf, sample_indices_buf);

        for( i = 0; i < n; i++ )
        {
            int idx = labels[i];
            double w = weights[idx]/*priors[values[i] > 0]*/;
            double t = values[i];
            rcw[0] += w;
            subtree_weights[i] = w;
            sum += t*w;
            sum2 += t*t*w;
        }

        iw = 1./rcw[0];
        node->value = sum*iw;
        node->node_risk = sum2 - (sum*iw)*sum;

        // renormalize the risk, as in try_split_node the unweighted formula
        // sqrt(risk)/n is used, rather than sqrt(risk)/sum(weights_i)
        node->node_risk *= n*iw*n*iw;
    }

    // store summary weights
    subtree_weights[n] = rcw[0];
    subtree_weights[n+1] = rcw[1];
}
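
// Worked example for the REAL branch above: if the weighted fraction of
// class-1 samples in a leaf is p = 0.9, the stored value is
// 0.5*log_ratio(0.9) = 0.5*log(0.9/0.1) ~= 1.10, a confident vote for
// class 1; p = 0.5 stores exactly 0.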
void CvBoostTree::read( CvFileStorage* fs, CvFileNode* fnode, CvBoost* _ensemble, CvDTreeTrainData* _data )
{
    CvDTree::read( fs, fnode, _data );
    ensemble = _ensemble;
}

void CvBoostTree::read( CvFileStorage*, CvFileNode* )
{
    assert(0);
}

void CvBoostTree::read( CvFileStorage* _fs, CvFileNode* _node,
                        CvDTreeTrainData* _data )
{
    CvDTree::read( _fs, _node, _data );
}

/////////////////////////////////// CvBoost /////////////////////////////////////

CvBoost::CvBoost()
{
    data = 0;
    weak = 0;
    default_model_name = "my_boost_tree";

    active_vars = active_vars_abs = orig_response = sum_response = weak_eval =
        subsample_mask = weights = subtree_weights = 0;
    have_active_cat_vars = have_subsample = false;

    clear();
}

void CvBoost::prune( CvSlice slice )
{
    if( weak && weak->total > 0 )
    {
        CvSeqReader reader;
        int i, count = cvSliceLength( slice, weak );

        cvStartReadSeq( weak, &reader );
        cvSetSeqReaderPos( &reader, slice.start_index );

        for( i = 0; i < count; i++ )
        {
            CvBoostTree* w;
            CV_READ_SEQ_ELEM( w, reader );
            delete w;
        }

        cvSeqRemoveSlice( weak, slice );
    }
}
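
// Hypothetical usage: drop every weak classifier from index 10 onwards,
// keeping only the first ten trees of the ensemble:
//
//     boost.prune( cvSlice(10, CV_WHOLE_SEQ_END_INDEX) );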
void CvBoost::clear()
{
    if( weak )
    {
        prune( CV_WHOLE_SEQ );
        cvReleaseMemStorage( &weak->storage );
    }
    if( data )
        delete data;
    weak = 0;
    data = 0;
    cvReleaseMat( &active_vars );
    cvReleaseMat( &active_vars_abs );
    cvReleaseMat( &orig_response );
    cvReleaseMat( &sum_response );
    cvReleaseMat( &weak_eval );
    cvReleaseMat( &subsample_mask );
    cvReleaseMat( &weights );
    cvReleaseMat( &subtree_weights );

    have_subsample = false;
}

CvBoost::~CvBoost()
{
    clear();
}

CvBoost::CvBoost( const CvMat* _train_data, int _tflag,
                  const CvMat* _responses, const CvMat* _var_idx,
                  const CvMat* _sample_idx, const CvMat* _var_type,
                  const CvMat* _missing_mask, CvBoostParams _params )
{
    weak = 0;
    data = 0;
    default_model_name = "my_boost_tree";

    active_vars = active_vars_abs = orig_response = sum_response = weak_eval =
        subsample_mask = weights = subtree_weights = 0;

    train( _train_data, _tflag, _responses, _var_idx, _sample_idx,
           _var_type, _missing_mask, _params );
}

bool
CvBoost::set_params( const CvBoostParams& _params )
{
    bool ok = false;

    CV_FUNCNAME( "CvBoost::set_params" );

    __BEGIN__;

    params = _params;
    if( params.boost_type != DISCRETE && params.boost_type != REAL &&
        params.boost_type != LOGIT && params.boost_type != GENTLE )
        CV_ERROR( CV_StsBadArg, "Unknown/unsupported boosting type" );

    params.weak_count = MAX( params.weak_count, 1 );
    params.weight_trim_rate = MAX( params.weight_trim_rate, 0. );
    params.weight_trim_rate = MIN( params.weight_trim_rate, 1. );
    if( params.weight_trim_rate < FLT_EPSILON )
        params.weight_trim_rate = 1.f;

    if( params.boost_type == DISCRETE &&
        params.split_criteria != GINI && params.split_criteria != MISCLASS )
        params.split_criteria = MISCLASS;

    if( params.boost_type == REAL &&
        params.split_criteria != GINI && params.split_criteria != MISCLASS )
        params.split_criteria = GINI;

    if( (params.boost_type == LOGIT || params.boost_type == GENTLE) &&
        params.split_criteria != SQERR )
        params.split_criteria = SQERR;

    ok = true;

    __END__;

    return ok;
}
bool
CvBoost::train( const CvMat* _train_data, int _tflag,
                const CvMat* _responses, const CvMat* _var_idx,
                const CvMat* _sample_idx, const CvMat* _var_type,
                const CvMat* _missing_mask,
                CvBoostParams _params, bool _update )
{
    bool ok = false;
    CvMemStorage* storage = 0;

    CV_FUNCNAME( "CvBoost::train" );

    __BEGIN__;

    int i;

    set_params( _params );

    cvReleaseMat( &active_vars );
    cvReleaseMat( &active_vars_abs );

    if( !_update || !data )
    {
        clear();
        data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,
            _sample_idx, _var_type, _missing_mask, _params, true, true );

        if( data->get_num_classes() != 2 )
            CV_ERROR( CV_StsNotImplemented,
            "Boosted trees can only be used for 2-class classification." );
        CV_CALL( storage = cvCreateMemStorage() );
        weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage );
        storage = 0;
    }
    else
    {
        data->set_data( _train_data, _tflag, _responses, _var_idx,
            _sample_idx, _var_type, _missing_mask, _params, true, true, true );
    }

    if ( (_params.boost_type == LOGIT) || (_params.boost_type == GENTLE) )
        data->do_responses_copy();

    update_weights( 0 );

    for( i = 0; i < params.weak_count; i++ )
    {
        CvBoostTree* tree = new CvBoostTree;
        if( !tree->train( data, subsample_mask, this ) )
        {
            delete tree;
            break;
        }
        //cvCheckArr( get_weak_response());
        cvSeqPush( weak, &tree );
        update_weights( tree );
        trim_weights();
        if( cvCountNonZero(subsample_mask) == 0 )
            break;
    }

    if(weak->total > 0)
    {
        get_active_vars(); // recompute active_vars* maps and condensed_idx's in the splits.
        data->is_classifier = true;
        data->free_train_data();
        ok = true;
    }
    else
        clear();

    __END__;

    return ok;
}
bool CvBoost::train( CvMLData* _data,
                     CvBoostParams _params,
                     bool update )
{
    bool result = false;

    CV_FUNCNAME( "CvBoost::train" );

    __BEGIN__;

    const CvMat* values = _data->get_values();
    const CvMat* response = _data->get_responses();
    const CvMat* missing = _data->get_missing();
    const CvMat* var_types = _data->get_var_types();
    const CvMat* train_sidx = _data->get_train_sample_idx();
    const CvMat* var_idx = _data->get_var_idx();

    CV_CALL( result = train( values, CV_ROW_SAMPLE, response, var_idx,
        train_sidx, var_types, missing, _params, update ) );

    __END__;

    return result;
}
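
/* A minimal end-to-end sketch under assumed data (train_data: a CV_32FC1
   CvMat with one sample per row, responses: the corresponding 2-class
   labels, sample: one test vector; all hypothetical here). Default values
   for the remaining predict() arguments are assumed to be declared in the
   class header:

     CvBoostParams params( CvBoost::GENTLE, 200, 0.95, 2, false, 0 );
     CvBoost boost;
     if( boost.train( train_data, CV_ROW_SAMPLE, responses,
                      0, 0, 0, 0, params ) )
     {
         float label = boost.predict( sample, 0 );
     }
*/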
void CvBoost::initialize_weights(double (&p)[2])
{
    p[0] = 1.;
    p[1] = 1.;
}
void
CvBoost::update_weights( CvBoostTree* tree )
{
    CV_FUNCNAME( "CvBoost::update_weights" );

    __BEGIN__;

    int i, n = data->sample_count;
    double sumw = 0.;
    int step = 0;
    float* fdata = 0;
    int* sample_idx_buf;
    const int* sample_idx = 0;
    cv::AutoBuffer<uchar> inn_buf;
    size_t _buf_size = (params.boost_type == LOGIT) || (params.boost_type == GENTLE) ? (size_t)(data->sample_count)*sizeof(int) : 0;
    if( !tree )
        _buf_size += n*sizeof(int);
    else
    {
        if( have_subsample )
            _buf_size += data->get_length_subbuf()*(sizeof(float)+sizeof(uchar));
    }
    inn_buf.allocate(_buf_size);
    uchar* cur_buf_pos = inn_buf.data();

    if ( (params.boost_type == LOGIT) || (params.boost_type == GENTLE) )
    {
        step = CV_IS_MAT_CONT(data->responses_copy->type) ?
            1 : data->responses_copy->step / CV_ELEM_SIZE(data->responses_copy->type);
        fdata = data->responses_copy->data.fl;
        sample_idx_buf = (int*)cur_buf_pos;
        cur_buf_pos = (uchar*)(sample_idx_buf + data->sample_count);
        sample_idx = data->get_sample_indices( data->data_root, sample_idx_buf );
    }
    CvMat* dtree_data_buf = data->buf;
    size_t length_buf_row = data->get_length_subbuf();
    if( !tree ) // before training the first tree, initialize weights and other parameters
    {
        int* class_labels_buf = (int*)cur_buf_pos;
        cur_buf_pos = (uchar*)(class_labels_buf + n);
        const int* class_labels = data->get_class_labels(data->data_root, class_labels_buf);
        // in case of logitboost and gentle adaboost each weak tree is a regression tree,
        // so we need to convert class labels to floating-point values
        double w0 = 1./n;
        double p[2] = { 1., 1. };
        initialize_weights(p);

        cvReleaseMat( &orig_response );
        cvReleaseMat( &sum_response );
        cvReleaseMat( &weak_eval );
        cvReleaseMat( &subsample_mask );
        cvReleaseMat( &weights );
        cvReleaseMat( &subtree_weights );

        CV_CALL( orig_response = cvCreateMat( 1, n, CV_32S ));
        CV_CALL( weak_eval = cvCreateMat( 1, n, CV_64F ));
        CV_CALL( subsample_mask = cvCreateMat( 1, n, CV_8U ));
        CV_CALL( weights = cvCreateMat( 1, n, CV_64F ));
        CV_CALL( subtree_weights = cvCreateMat( 1, n + 2, CV_64F ));

        if( data->have_priors )
        {
            // compute the weight scale for each class from their prior probabilities
            int c1 = 0;
            for( i = 0; i < n; i++ )
                c1 += class_labels[i];
            p[0] = data->priors->data.db[0]*(c1 < n ? 1./(n - c1) : 0.);
            p[1] = data->priors->data.db[1]*(c1 > 0 ? 1./c1 : 0.);
            p[0] /= p[0] + p[1];
            p[1] = 1. - p[0];
        }

        if (data->is_buf_16u)
        {
            unsigned short* labels = (unsigned short*)(dtree_data_buf->data.s + data->data_root->buf_idx*length_buf_row +
                data->data_root->offset + (size_t)(data->work_var_count-1)*data->sample_count);
            for( i = 0; i < n; i++ )
            {
                // save original categorical responses {0,1}, convert them to {-1,1}
                orig_response->data.i[i] = class_labels[i]*2 - 1;
                // make all the samples active at start.
                // later, in trim_weights(), some of them can be deactivated/reactivated, if needed
                subsample_mask->data.ptr[i] = (uchar)1;
                // make all the initial weights the same.
                weights->data.db[i] = w0*p[class_labels[i]];
                // set the labels to find (from within the weak tree learning proc)
                // the particular sample weight, and where to store the response.
                labels[i] = (unsigned short)i;
            }
        }
        else
        {
            int* labels = dtree_data_buf->data.i + data->data_root->buf_idx*length_buf_row +
                data->data_root->offset + (size_t)(data->work_var_count-1)*data->sample_count;

            for( i = 0; i < n; i++ )
            {
                // save original categorical responses {0,1}, convert them to {-1,1}
                orig_response->data.i[i] = class_labels[i]*2 - 1;
                // make all the samples active at start.
                // later, in trim_weights(), some of them can be deactivated/reactivated, if needed
                subsample_mask->data.ptr[i] = (uchar)1;
                // make all the initial weights the same.
                weights->data.db[i] = w0*p[class_labels[i]];
                // set the labels to find (from within the weak tree learning proc)
                // the particular sample weight, and where to store the response.
                labels[i] = i;
            }
        }

        if( params.boost_type == LOGIT )
        {
            CV_CALL( sum_response = cvCreateMat( 1, n, CV_64F ));

            for( i = 0; i < n; i++ )
            {
                sum_response->data.db[i] = 0;
                fdata[sample_idx[i]*step] = orig_response->data.i[i] > 0 ? 2.f : -2.f;
            }

            // in case of logitboost each weak tree is a regression tree.
            // the target function values are recalculated for each of the trees
            data->is_classifier = false;
        }
        else if( params.boost_type == GENTLE )
        {
            for( i = 0; i < n; i++ )
                fdata[sample_idx[i]*step] = (float)orig_response->data.i[i];

            data->is_classifier = false;
        }
    }
    else
    {
        // at this moment, for all the samples that participated in the training of the most
        // recent weak classifier we know the responses. For other samples we need to compute them
        if( have_subsample )
        {
            float* values = (float*)cur_buf_pos;
            cur_buf_pos = (uchar*)(values + data->get_length_subbuf());
            uchar* missing = cur_buf_pos;
            cur_buf_pos = missing + data->get_length_subbuf() * (size_t)CV_ELEM_SIZE(data->buf->type);

            CvMat _sample, _mask;

            // invert the subsample mask
            cvXorS( subsample_mask, cvScalar(1.), subsample_mask );
            data->get_vectors( subsample_mask, values, missing, 0 );

            _sample = cvMat( 1, data->var_count, CV_32F );
            _mask = cvMat( 1, data->var_count, CV_8U );

            // run the tree through all the non-processed samples
            for( i = 0; i < n; i++ )
                if( subsample_mask->data.ptr[i] )
                {
                    _sample.data.fl = values;
                    _mask.data.ptr = missing;
                    values += _sample.cols;
                    missing += _mask.cols;
                    weak_eval->data.db[i] = tree->predict( &_sample, &_mask, true )->value;
                }
        }

        // now update the weights and other parameters for each type of boosting
        if( params.boost_type == DISCRETE )
        {
            // Discrete AdaBoost:
            //   weak_eval[i] (=f(x_i)) is in {-1,1}
            //   err = sum(w_i*(f(x_i) != y_i))/sum(w_i)
            //   C = log((1-err)/err)
            //   w_i *= exp(C*(f(x_i) != y_i))
            double C, err = 0.;
            double scale[] = { 1., 0. };

            for( i = 0; i < n; i++ )
            {
                double w = weights->data.db[i];
                sumw += w;
                err += w*(weak_eval->data.db[i] != orig_response->data.i[i]);
            }

            if( sumw != 0 )
                err /= sumw;
            C = err = -log_ratio( err );
            scale[1] = exp(err);

            sumw = 0;
            for( i = 0; i < n; i++ )
            {
                double w = weights->data.db[i]*
                    scale[weak_eval->data.db[i] != orig_response->data.i[i]];
                sumw += w;
                weights->data.db[i] = w;
            }

            tree->scale( C );
        }
        else if( params.boost_type == REAL )
        {
            // Real AdaBoost:
            //   weak_eval[i] = f(x_i) = 0.5*log(p(x_i)/(1-p(x_i))), p(x_i)=P(y=1|x_i)
            //   w_i *= exp(-y_i*f(x_i))
            for( i = 0; i < n; i++ )
                weak_eval->data.db[i] *= -orig_response->data.i[i];

            cvExp( weak_eval, weak_eval );

            for( i = 0; i < n; i++ )
            {
                double w = weights->data.db[i]*weak_eval->data.db[i];
                sumw += w;
                weights->data.db[i] = w;
            }
        }
        else if( params.boost_type == LOGIT )
        {
            // LogitBoost:
            //   weak_eval[i] = f(x_i) in [-z_max,z_max]
            //   sum_response = F(x_i).
            //   F(x_i) += 0.5*f(x_i)
            //   p(x_i) = exp(F(x_i))/(exp(F(x_i)) + exp(-F(x_i))) = 1/(1+exp(-2*F(x_i)))
            //   reuse weak_eval: weak_eval[i] <- p(x_i)
            //   w_i = p(x_i)*(1 - p(x_i))
            //   z_i = ((y_i+1)/2 - p(x_i))/(p(x_i)*(1 - p(x_i)))
            //   store z_i to the data->data_root as the new target responses
            const double lb_weight_thresh = FLT_EPSILON;
            const double lb_z_max = 10.;
            /*float* responses_buf = data->get_resp_float_buf();
            const float* responses = 0;
            data->get_ord_responses(data->data_root, responses_buf, &responses);*/

            for( i = 0; i < n; i++ )
            {
                double s = sum_response->data.db[i] + 0.5*weak_eval->data.db[i];
                sum_response->data.db[i] = s;
                weak_eval->data.db[i] = -2*s;
            }

            cvExp( weak_eval, weak_eval );

            for( i = 0; i < n; i++ )
            {
                double p = 1./(1. + weak_eval->data.db[i]);
                double w = p*(1 - p), z;
                w = MAX( w, lb_weight_thresh );
                weights->data.db[i] = w;
                sumw += w;
                if( orig_response->data.i[i] > 0 )
                {
                    z = 1./p;
                    fdata[sample_idx[i]*step] = (float)MIN(z, lb_z_max);
                }
                else
                {
                    z = 1./(1-p);
                    fdata[sample_idx[i]*step] = (float)-MIN(z, lb_z_max);
                }
            }
        }
        else
        {
            // Gentle AdaBoost:
            //   weak_eval[i] = f(x_i) in [-1,1]
            //   w_i *= exp(-y_i*f(x_i))
            assert( params.boost_type == GENTLE );

            for( i = 0; i < n; i++ )
                weak_eval->data.db[i] *= -orig_response->data.i[i];

            cvExp( weak_eval, weak_eval );

            for( i = 0; i < n; i++ )
            {
                double w = weights->data.db[i] * weak_eval->data.db[i];
                weights->data.db[i] = w;
                sumw += w;
            }
        }
    }

    // renormalize the weights
    if( sumw > FLT_EPSILON )
    {
        sumw = 1./sumw;
        for( i = 0; i < n; ++i )
            weights->data.db[i] *= sumw;
    }

    __END__;
}
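
// Numeric illustration of the DISCRETE branch above: with weighted error
// err = 0.3, C = log((1-err)/err) ~= 0.847 and exp(C) = 7/3, so each
// misclassified sample's weight grows by a factor of 7/3 while correctly
// classified ones keep scale 1; the final loop then renormalizes all
// weights to sum to 1.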
void
CvBoost::trim_weights()
{
    //CV_FUNCNAME( "CvBoost::trim_weights" );

    __BEGIN__;

    int i, count = data->sample_count, nz_count = 0;
    double sum, threshold;

    if( params.weight_trim_rate <= 0. || params.weight_trim_rate >= 1. )
        EXIT;

    // use weak_eval as a temporary buffer for sorted weights
    cvCopy( weights, weak_eval );

    std::sort(weak_eval->data.db, weak_eval->data.db + count);

    // as weight trimming occurs immediately after updating the weights,
    // where they are renormalized, we assume that the weight sum = 1.
    sum = 1. - params.weight_trim_rate;

    for( i = 0; i < count; i++ )
    {
        double w = weak_eval->data.db[i];
        if( sum <= 0 )
            break;
        sum -= w;
    }

    threshold = i < count ? weak_eval->data.db[i] : DBL_MAX;

    for( i = 0; i < count; i++ )
    {
        double w = weights->data.db[i];
        int f = w >= threshold;
        subsample_mask->data.ptr[i] = (uchar)f;
        nz_count += f;
    }

    have_subsample = nz_count < count;

    __END__;
}
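
// Example of the trimming rule above: with weight_trim_rate = 0.95 and
// weights normalized to sum to 1, the sorted pass finds the threshold at
// which the smallest weights accumulate to about 5% of the total mass;
// only samples with weight >= threshold keep subsample_mask[i] = 1 and
// participate in training the next weak classifier.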
const CvMat*
CvBoost::get_active_vars( bool absolute_idx )
{
    CvMat* mask = 0;
    CvMat* inv_map = 0;
    CvMat* result = 0;

    CV_FUNCNAME( "CvBoost::get_active_vars" );

    __BEGIN__;

    if( !weak )
        CV_ERROR( CV_StsError, "The boosted tree ensemble has not been trained yet" );

    if( !active_vars || !active_vars_abs )
    {
        CvSeqReader reader;
        int i, j, nactive_vars;
        CvBoostTree* wtree;
        const CvDTreeNode* node;

        assert(!active_vars && !active_vars_abs);
        mask = cvCreateMat( 1, data->var_count, CV_8U );
        inv_map = cvCreateMat( 1, data->var_count, CV_32S );
        cvZero( mask );
        cvSet( inv_map, cvScalar(-1) );

        // first pass: compute the mask of used variables
        cvStartReadSeq( weak, &reader );
        for( i = 0; i < weak->total; i++ )
        {
            CV_READ_SEQ_ELEM(wtree, reader);

            node = wtree->get_root();
            assert( node != 0 );
            for(;;)
            {
                const CvDTreeNode* parent;
                for(;;)
                {
                    CvDTreeSplit* split = node->split;
                    for( ; split != 0; split = split->next )
                        mask->data.ptr[split->var_idx] = 1;
                    if( !node->left )
                        break;
                    node = node->left;
                }
                for( parent = node->parent; parent && parent->right == node;
                     node = parent, parent = parent->parent )
                    ;
                if( !parent )
                    break;
                node = parent->right;
            }
        }

        nactive_vars = cvCountNonZero(mask);

        //if ( nactive_vars > 0 )
        {
            active_vars = cvCreateMat( 1, nactive_vars, CV_32S );
            active_vars_abs = cvCreateMat( 1, nactive_vars, CV_32S );

            have_active_cat_vars = false;

            for( i = j = 0; i < data->var_count; i++ )
            {
                if( mask->data.ptr[i] )
                {
                    active_vars->data.i[j] = i;
                    active_vars_abs->data.i[j] = data->var_idx ? data->var_idx->data.i[i] : i;
                    inv_map->data.i[i] = j;
                    if( data->var_type->data.i[i] >= 0 )
                        have_active_cat_vars = true;
                    j++;
                }
            }

            // second pass: now compute the condensed indices
            cvStartReadSeq( weak, &reader );
            for( i = 0; i < weak->total; i++ )
            {
                CV_READ_SEQ_ELEM(wtree, reader);
                node = wtree->get_root();
                for(;;)
                {
                    const CvDTreeNode* parent;
                    for(;;)
                    {
                        CvDTreeSplit* split = node->split;
                        for( ; split != 0; split = split->next )
                        {
                            split->condensed_idx = inv_map->data.i[split->var_idx];
                            assert( split->condensed_idx >= 0 );
                        }
                        if( !node->left )
                            break;
                        node = node->left;
                    }
                    for( parent = node->parent; parent && parent->right == node;
                         node = parent, parent = parent->parent )
                        ;
                    if( !parent )
                        break;
                    node = parent->right;
                }
            }
        }
    }

    result = absolute_idx ? active_vars_abs : active_vars;

    __END__;

    cvReleaseMat( &mask );
    cvReleaseMat( &inv_map );

    return result;
}
float
CvBoost::predict( const CvMat* _sample, const CvMat* _missing,
                  CvMat* weak_responses, CvSlice slice,
                  bool raw_mode, bool return_sum ) const
{
    float value = -FLT_MAX;

    CvSeqReader reader;
    double sum = 0;
    int wstep = 0;
    const float* sample_data;

    if( !weak )
        CV_Error( CV_StsError, "The boosted tree ensemble has not been trained yet" );

    if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||
        (_sample->cols != 1 && _sample->rows != 1) ||
        (_sample->cols + _sample->rows - 1 != data->var_all && !raw_mode) ||
        (active_vars && _sample->cols + _sample->rows - 1 != active_vars->cols && raw_mode) )
            CV_Error( CV_StsBadArg,
        "the input sample must be a 1d floating-point vector with the same "
        "number of elements as the total number of variables or "
        "as the number of variables used for training" );

    if( _missing )
    {
        if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||
            !CV_ARE_SIZES_EQ(_missing, _sample) )
            CV_Error( CV_StsBadArg,
            "the missing data mask must be an 8-bit vector of the same size as the input sample" );
    }

    int i, weak_count = cvSliceLength( slice, weak );
    if( weak_count >= weak->total )
    {
        weak_count = weak->total;
        slice.start_index = 0;
    }

    if( weak_responses )
    {
        if( !CV_IS_MAT(weak_responses) ||
            CV_MAT_TYPE(weak_responses->type) != CV_32FC1 ||
            (weak_responses->cols != 1 && weak_responses->rows != 1) ||
            weak_responses->cols + weak_responses->rows - 1 != weak_count )
            CV_Error( CV_StsBadArg,
            "The output matrix of weak classifier responses must be a valid "
            "floating-point vector with the same number of components as the length of the input slice" );
        wstep = CV_IS_MAT_CONT(weak_responses->type) ? 1 : weak_responses->step/sizeof(float);
    }

    int var_count = active_vars->cols;
    const int* vtype = data->var_type->data.i;
    const int* cmap = data->cat_map->data.i;
    const int* cofs = data->cat_ofs->data.i;

    cv::Mat sample = cv::cvarrToMat(_sample);
    cv::Mat missing;
    if( _missing )
        missing = cv::cvarrToMat(_missing);

    // if needed, preprocess the input vector
    if( !raw_mode )
    {
        int sstep, mstep = 0;
        const float* src_sample;
        const uchar* src_mask = 0;
        float* dst_sample;
        uchar* dst_mask;
        const int* vidx = active_vars->data.i;
        const int* vidx_abs = active_vars_abs->data.i;
        bool have_mask = _missing != 0;

        sample = cv::Mat(1, var_count, CV_32FC1);
        missing = cv::Mat(1, var_count, CV_8UC1);

        dst_sample = sample.ptr<float>();
        dst_mask = missing.ptr<uchar>();

        src_sample = _sample->data.fl;
        sstep = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(src_sample[0]);

        if( _missing )
        {
            src_mask = _missing->data.ptr;
            mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step;
        }

        for( i = 0; i < var_count; i++ )
        {
            int idx = vidx[i], idx_abs = vidx_abs[i];
            float val = src_sample[idx_abs*sstep];
            int ci = vtype[idx];
            uchar m = src_mask ? src_mask[idx_abs*mstep] : (uchar)0;

            if( ci >= 0 ) // categorical variable
            {
                int a = cofs[ci], b = (ci+1 >= data->cat_ofs->cols) ?
                    data->cat_map->cols : cofs[ci+1], c = a;
                int ival = cvRound(val);
                if( (ival != val) && (!m) )
                    CV_Error( CV_StsBadArg,
                        "one of the input categorical variables is not an integer" );
                while( a < b )
                {
                    c = (a + b) >> 1;
                    if( ival < cmap[c] )
                        b = c;
                    else if( ival > cmap[c] )
                        a = c+1;
                    else
                        break;
                }

                if( c < 0 || ival != cmap[c] )
                {
                    m = 1;
                    have_mask = true;
                }
                else
                {
                    val = (float)(c - cofs[ci]);
                }
            }

            dst_sample[i] = val;
            dst_mask[i] = m;
        }

        if( !have_mask )
            missing.release();
    }
    else
    {
        if( !CV_IS_MAT_CONT(_sample->type & (_missing ? _missing->type : -1)) )
            CV_Error( CV_StsBadArg, "In raw mode the input vectors must be continuous" );
    }

    cvStartReadSeq( weak, &reader );
    cvSetSeqReaderPos( &reader, slice.start_index );

    sample_data = sample.ptr<float>();

    if( !have_active_cat_vars && missing.empty() && !weak_responses )
    {
        // fast path: ordered variables only, no missing values, no per-tree output
        for( i = 0; i < weak_count; i++ )
        {
            CvBoostTree* wtree;
            const CvDTreeNode* node;
            CV_READ_SEQ_ELEM( wtree, reader );

            node = wtree->get_root();
            while( node->left )
            {
                CvDTreeSplit* split = node->split;
                int vi = split->condensed_idx;
                float val = sample_data[vi];
                int dir = val <= split->ord.c ? -1 : 1;
                if( split->inversed )
                    dir = -dir;
                node = dir < 0 ? node->left : node->right;
            }
            sum += node->value;
        }
    }
    else
    {
        const int* avars = active_vars->data.i;
        const uchar* m = !missing.empty() ? missing.ptr<uchar>() : 0;

        // full-featured version
        for( i = 0; i < weak_count; i++ )
        {
            CvBoostTree* wtree;
            const CvDTreeNode* node;
            CV_READ_SEQ_ELEM( wtree, reader );

            node = wtree->get_root();
            while( node->left )
            {
                const CvDTreeSplit* split = node->split;
                int dir = 0;
                for( ; !dir && split != 0; split = split->next )
                {
                    int vi = split->condensed_idx;
                    int ci = vtype[avars[vi]];
                    float val = sample_data[vi];
                    if( m && m[vi] )
                        continue;
                    if( ci < 0 ) // ordered
                        dir = val <= split->ord.c ? -1 : 1;
                    else // categorical
                    {
                        int c = cvRound(val);
                        dir = CV_DTREE_CAT_DIR(c, split->subset);
                    }
                    if( split->inversed )
                        dir = -dir;
                }

                if( !dir )
                {
                    // primary and surrogate splits all failed: follow the larger branch
                    int diff = node->right->sample_count - node->left->sample_count;
                    dir = diff < 0 ? -1 : 1;
                }
                node = dir < 0 ? node->left : node->right;
            }
            if( weak_responses )
                weak_responses->data.fl[i*wstep] = (float)node->value;
            sum += node->value;
        }
    }

    if( return_sum )
        value = (float)sum;
    else
    {
        int cls_idx = sum >= 0;
        if( raw_mode )
            value = (float)cls_idx;
        else
            value = (float)cmap[cofs[vtype[data->var_count]] + cls_idx];
    }

    return value;
}
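
/* A minimal usage sketch for predict(); `boost` is assumed to be a trained
   CvBoost instance and `sample_values` an array of data->var_all packed floats:

       CvMat sample = cvMat( 1, var_all, CV_32FC1, sample_values );
       float label = boost.predict( &sample );                  // class label
       float votes = boost.predict( &sample, 0, 0,
                                    CV_WHOLE_SEQ, false, true ); // raw weighted sum

   An optional 8-bit mask of the same size, with nonzero entries marking
   missing values, may be passed as the second argument. */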
float CvBoost::calc_error( CvMLData* _data, int type, std::vector<float> *resp )
{
    float err = 0;
    const CvMat* values = _data->get_values();
    const CvMat* response = _data->get_responses();
    const CvMat* missing = _data->get_missing();
    const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();
    const CvMat* var_types = _data->get_var_types();
    int* sidx = sample_idx ? sample_idx->data.i : 0;
    int r_step = CV_IS_MAT_CONT(response->type) ?
                1 : response->step / CV_ELEM_SIZE(response->type);
    bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;
    int sample_count = sample_idx ? sample_idx->cols : 0;
    sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? values->rows : sample_count;
    float* pred_resp = 0;
    if( resp && (sample_count > 0) )
    {
        resp->resize( sample_count );
        pred_resp = &((*resp)[0]);
    }
    if ( is_classifier )
    {
        for( int i = 0; i < sample_count; i++ )
        {
            CvMat sample, miss;
            int si = sidx ? sidx[i] : i;
            cvGetRow( values, &sample, si );
            if( missing )
                cvGetRow( missing, &miss, si );
            float r = (float)predict( &sample, missing ? &miss : 0 );
            if( pred_resp )
                pred_resp[i] = r;
            int d = fabs((double)r - response->data.fl[si*r_step]) <= FLT_EPSILON ? 0 : 1;
            err += d;
        }
        err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
    }
    else
    {
        for( int i = 0; i < sample_count; i++ )
        {
            CvMat sample, miss;
            int si = sidx ? sidx[i] : i;
            cvGetRow( values, &sample, si );
            if( missing )
                cvGetRow( missing, &miss, si );
            float r = (float)predict( &sample, missing ? &miss : 0 );
            if( pred_resp )
                pred_resp[i] = r;
            float d = r - response->data.fl[si*r_step];
            err += d*d;
        }
        err = sample_count ? err / (float)sample_count : -FLT_MAX;
    }
    return err;
}
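
/* A short sketch of error evaluation through CvMLData; the CSV file name and
   response column index are illustrative assumptions:

       CvMLData mldata;
       mldata.read_csv( "train.csv" );   // samples, one per row
       mldata.set_response_idx( 0 );     // column holding the responses
       CvBoost boost;
       boost.train( &mldata );
       float train_err = boost.calc_error( &mldata, CV_TRAIN_ERROR );

   For classification the result is the misclassification rate in percent;
   for regression it is the mean squared error (see the two branches above). */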
void CvBoost::write_params( CvFileStorage* fs ) const
{
    const char* boost_type_str =
        params.boost_type == DISCRETE ? "DiscreteAdaboost" :
        params.boost_type == REAL ? "RealAdaboost" :
        params.boost_type == LOGIT ? "LogitBoost" :
        params.boost_type == GENTLE ? "GentleAdaboost" : 0;

    const char* split_crit_str =
        params.split_criteria == DEFAULT ? "Default" :
        params.split_criteria == GINI ? "Gini" :
        params.split_criteria == MISCLASS ? "Misclassification" :
        params.split_criteria == SQERR ? "SquaredErr" : 0;
    if( boost_type_str )
        cvWriteString( fs, "boosting_type", boost_type_str );
    else
        cvWriteInt( fs, "boosting_type", params.boost_type );

    if( split_crit_str )
        cvWriteString( fs, "splitting_criteria", split_crit_str );
    else
        cvWriteInt( fs, "splitting_criteria", params.split_criteria );

    cvWriteInt( fs, "ntrees", weak->total );
    cvWriteReal( fs, "weight_trimming_rate", params.weight_trim_rate );

    data->write_params( fs );
}
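
/* For reference, a sketch of what write_params() emits in a YAML storage
   (the values are illustrative; data->write_params() appends the
   training-data parameters afterwards):

       boosting_type: GentleAdaboost
       splitting_criteria: SquaredErr
       ntrees: 100
       weight_trimming_rate: 0.9500
*/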
void CvBoost::read_params( CvFileStorage* fs, CvFileNode* fnode )
{
    CV_FUNCNAME( "CvBoost::read_params" );

    __BEGIN__;

    CvFileNode* temp;

    if( !fnode || !CV_NODE_IS_MAP(fnode->tag) )
        return;

    data = new CvDTreeTrainData();
    CV_CALL( data->read_params(fs, fnode));
    data->shared = true;

    params.max_depth = data->params.max_depth;
    params.min_sample_count = data->params.min_sample_count;
    params.max_categories = data->params.max_categories;
    params.priors = data->params.priors;
    params.regression_accuracy = data->params.regression_accuracy;
    params.use_surrogates = data->params.use_surrogates;

    temp = cvGetFileNodeByName( fs, fnode, "boosting_type" );
    if( !temp )
        return;

    if( CV_NODE_IS_STRING(temp->tag) )
    {
        const char* boost_type_str = cvReadString( temp, "" );
        params.boost_type = strcmp( boost_type_str, "DiscreteAdaboost" ) == 0 ? DISCRETE :
                            strcmp( boost_type_str, "RealAdaboost" ) == 0 ? REAL :
                            strcmp( boost_type_str, "LogitBoost" ) == 0 ? LOGIT :
                            strcmp( boost_type_str, "GentleAdaboost" ) == 0 ? GENTLE : -1;
    }
    else
        params.boost_type = cvReadInt( temp, -1 );

    if( params.boost_type < DISCRETE || params.boost_type > GENTLE )
        CV_ERROR( CV_StsBadArg, "Unknown boosting type" );

    temp = cvGetFileNodeByName( fs, fnode, "splitting_criteria" );
    if( temp && CV_NODE_IS_STRING(temp->tag) )
    {
        const char* split_crit_str = cvReadString( temp, "" );
        params.split_criteria = strcmp( split_crit_str, "Default" ) == 0 ? DEFAULT :
                                strcmp( split_crit_str, "Gini" ) == 0 ? GINI :
                                strcmp( split_crit_str, "Misclassification" ) == 0 ? MISCLASS :
                                strcmp( split_crit_str, "SquaredErr" ) == 0 ? SQERR : -1;
    }
    else
        params.split_criteria = cvReadInt( temp, -1 );
    if( params.split_criteria < DEFAULT || params.split_criteria > SQERR )
        CV_ERROR( CV_StsBadArg, "Unknown split criteria" );
    params.weak_count = cvReadIntByName( fs, fnode, "ntrees" );
    params.weight_trim_rate = cvReadRealByName( fs, fnode, "weight_trimming_rate", 0. );

    __END__;
}
void
CvBoost::read( CvFileStorage* fs, CvFileNode* node )
{
    CV_FUNCNAME( "CvBoost::read" );

    __BEGIN__;

    CvSeqReader reader;
    CvFileNode* trees_fnode;
    CvMemStorage* storage;
    int i, ntrees;

    clear();
    read_params( fs, node );

    if( !data )
        EXIT;

    trees_fnode = cvGetFileNodeByName( fs, node, "trees" );
    if( !trees_fnode || !CV_NODE_IS_SEQ(trees_fnode->tag) )
        CV_ERROR( CV_StsParseError, "<trees> tag is missing" );

    cvStartReadSeq( trees_fnode->data.seq, &reader );
    ntrees = trees_fnode->data.seq->total;

    if( ntrees != params.weak_count )
        CV_ERROR( CV_StsUnmatchedSizes,
        "The number of trees stored does not match <ntrees> tag value" );

    CV_CALL( storage = cvCreateMemStorage() );
    weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage );

    for( i = 0; i < ntrees; i++ )
    {
        CvBoostTree* tree = new CvBoostTree();
        CV_CALL(tree->read( fs, (CvFileNode*)reader.ptr, this, data ));
        CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
        cvSeqPush( weak, &tree );
    }

    get_active_vars();

    __END__;
}
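
/* A minimal loading sketch; the file and node names are illustrative
   assumptions. The inherited CvStatModel::load() ends up calling read():

       CvBoost boost;
       boost.load( "boost_model.xml" );   // simplest form

   or, with an explicitly named node:

       CvFileStorage* fs = cvOpenFileStorage( "boost_model.xml", 0, CV_STORAGE_READ );
       boost.read( fs, cvGetFileNodeByName( fs, 0, "my_boost" ) );
       cvReleaseFileStorage( &fs );
*/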
void
CvBoost::write( CvFileStorage* fs, const char* name ) const
{
    CV_FUNCNAME( "CvBoost::write" );

    __BEGIN__;

    CvSeqReader reader;
    int i;

    cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_BOOSTING );

    if( !weak )
        CV_ERROR( CV_StsBadArg, "The classifier has not been trained yet" );

    write_params( fs );

    cvStartWriteStruct( fs, "trees", CV_NODE_SEQ );

    cvStartReadSeq( weak, &reader );

    for( i = 0; i < weak->total; i++ )
    {
        CvBoostTree* tree;
        CV_READ_SEQ_ELEM( tree, reader );
        cvStartWriteStruct( fs, 0, CV_NODE_MAP );
        tree->write( fs );
        cvEndWriteStruct( fs );
    }

    cvEndWriteStruct( fs );
    cvEndWriteStruct( fs );

    __END__;
}
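
/* The matching save sketch (file and node names are again illustrative):

       CvFileStorage* fs = cvOpenFileStorage( "boost_model.xml", 0, CV_STORAGE_WRITE );
       boost.write( fs, "my_boost" );
       cvReleaseFileStorage( &fs );

   or simply boost.save( "boost_model.xml" ), inherited from CvStatModel. */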
CvMat*
CvBoost::get_weights()
{
    return weights;
}


CvMat*
CvBoost::get_subtree_weights()
{
    return subtree_weights;
}


CvMat*
CvBoost::get_weak_response()
{
    return weak_eval;
}


const CvBoostParams&
CvBoost::get_params() const
{
    return params;
}


CvSeq* CvBoost::get_weak_predictors()
{
    return weak;
}


const CvDTreeTrainData* CvBoost::get_data() const
{
    return data;
}
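
/* Sketch of iterating over the ensemble returned by get_weak_predictors();
   this mirrors the CvSeqReader pattern used by predict() above:

       CvSeq* trees = boost.get_weak_predictors();
       CvSeqReader reader;
       cvStartReadSeq( trees, &reader );
       for( int i = 0; i < trees->total; i++ )
       {
           CvBoostTree* tree;
           CV_READ_SEQ_ELEM( tree, reader );
           // inspect tree->get_root() here
       }
*/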
using namespace cv;

CvBoost::CvBoost( const Mat& _train_data, int _tflag,
                  const Mat& _responses, const Mat& _var_idx,
                  const Mat& _sample_idx, const Mat& _var_type,
                  const Mat& _missing_mask,
                  CvBoostParams _params )
{
    weak = 0;
    data = 0;
    default_model_name = "my_boost_tree";

    active_vars = active_vars_abs = orig_response = sum_response = weak_eval =
        subsample_mask = weights = subtree_weights = 0;

    train( _train_data, _tflag, _responses, _var_idx, _sample_idx,
           _var_type, _missing_mask, _params );
}


bool
CvBoost::train( const Mat& _train_data, int _tflag,
                const Mat& _responses, const Mat& _var_idx,
                const Mat& _sample_idx, const Mat& _var_type,
                const Mat& _missing_mask,
                CvBoostParams _params, bool _update )
{
    train_data_hdr = cvMat(_train_data);
    train_data_mat = _train_data;
    responses_hdr = cvMat(_responses);
    responses_mat = _responses;

    CvMat vidx = cvMat(_var_idx), sidx = cvMat(_sample_idx), vtype = cvMat(_var_type), mmask = cvMat(_missing_mask);

    return train(&train_data_hdr, _tflag, &responses_hdr, vidx.data.ptr ? &vidx : 0,
                 sidx.data.ptr ? &sidx : 0, vtype.data.ptr ? &vtype : 0,
                 mmask.data.ptr ? &mmask : 0, _params, _update);
}
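
/* A minimal training sketch for the cv::Mat overload; the matrix contents and
   parameter values are illustrative assumptions:

       cv::Mat train_data( nsamples, nvars, CV_32FC1 );  // one sample per row
       cv::Mat responses( nsamples, 1, CV_32FC1 );       // one label per sample
       CvBoostParams params( CvBoost::GENTLE, 100,       // Gentle AdaBoost, 100 trees
                             0.95, 5, false, 0 );        // trim rate, depth, surrogates, priors
       CvBoost boost;
       boost.train( train_data, CV_ROW_SAMPLE, responses,
                    cv::Mat(), cv::Mat(), cv::Mat(), cv::Mat(), params );
*/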
float
CvBoost::predict( const Mat& _sample, const Mat& _missing,
                  const Range& slice, bool raw_mode, bool return_sum ) const
{
    CvMat sample = cvMat(_sample), mmask = cvMat(_missing);
    return predict(&sample, _missing.empty() ? 0 : &mmask, 0,
                   slice == Range::all() ? CV_WHOLE_SEQ : cvSlice(slice.start, slice.end),
                   raw_mode, return_sum);
}
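
/* Usage sketch for this cv::Mat overload (names and sizes illustrative):

       cv::Mat sample( 1, nvars, CV_32FC1 );
       float label = boost.predict( sample );                        // class label
       float sum   = boost.predict( sample, cv::Mat(),
                                    cv::Range::all(), false, true ); // raw vote sum
*/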
/* End of file. */