ml.h 53 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477
  1. /*M///////////////////////////////////////////////////////////////////////////////////////
  2. //
  3. // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
  4. //
  5. // By downloading, copying, installing or using the software you agree to this license.
  6. // If you do not agree to this license, do not download, install,
  7. // copy or use the software.
  8. //
  9. //
  10. // Intel License Agreement
  11. //
  12. // Copyright (C) 2000, Intel Corporation, all rights reserved.
  13. // Third party copyrights are property of their respective owners.
  14. //
  15. // Redistribution and use in source and binary forms, with or without modification,
  16. // are permitted provided that the following conditions are met:
  17. //
  18. // * Redistribution's of source code must retain the above copyright notice,
  19. // this list of conditions and the following disclaimer.
  20. //
  21. // * Redistribution's in binary form must reproduce the above copyright notice,
  22. // this list of conditions and the following disclaimer in the documentation
  23. // and/or other materials provided with the distribution.
  24. //
  25. // * The name of Intel Corporation may not be used to endorse or promote products
  26. // derived from this software without specific prior written permission.
  27. //
  28. // This software is provided by the copyright holders and contributors "as is" and
  29. // any express or implied warranties, including, but not limited to, the implied
  30. // warranties of merchantability and fitness for a particular purpose are disclaimed.
  31. // In no event shall the Intel Corporation or contributors be liable for any direct,
  32. // indirect, incidental, special, exemplary, or consequential damages
  33. // (including, but not limited to, procurement of substitute goods or services;
  34. // loss of use, data, or profits; or business interruption) however caused
  35. // and on any theory of liability, whether in contract, strict liability,
  36. // or tort (including negligence or otherwise) arising in any way out of
  37. // the use of this software, even if advised of the possibility of such damage.
  38. //
  39. //M*/
  40. #ifndef __ML_H__
  41. #define __ML_H__
  42. // disable deprecation warning which appears in VisualStudio 8.0
  43. #if _MSC_VER >= 1400
  44. #pragma warning( disable : 4996 )
  45. #endif
  46. #include <cxcore.h>
  47. #include <limits.h>
  48. #ifdef __cplusplus
  49. extern "C" {
  50. #endif
  51. /****************************************************************************************\
  52. * Main struct definitions *
  53. \****************************************************************************************/
  54. /* log(2*PI) */
  55. #define CV_LOG2PI (1.8378770664093454835606594728112)
  56. /* columns of <trainData> matrix are training samples */
  57. #define CV_COL_SAMPLE 0
  58. /* rows of <trainData> matrix are training samples */
  59. #define CV_ROW_SAMPLE 1
  60. #define CV_IS_ROW_SAMPLE(flags) ((flags) & CV_ROW_SAMPLE)
  61. struct CvVectors
  62. {
  63. int type;
  64. int dims, count;
  65. CvVectors* next;
  66. union
  67. {
  68. uchar** ptr;
  69. float** fl;
  70. double** db;
  71. } data;
  72. };
  73. #if 0
  74. /* A structure, representing the lattice range of statmodel parameters.
  75. It is used for optimizing statmodel parameters by cross-validation method.
  76. The lattice is logarithmic, so <step> must be greater then 1. */
  77. typedef struct CvParamLattice
  78. {
  79. double min_val;
  80. double max_val;
  81. double step;
  82. }
  83. CvParamLattice;
  84. CV_INLINE CvParamLattice cvParamLattice( double min_val, double max_val,
  85. double log_step )
  86. {
  87. CvParamLattice pl;
  88. pl.min_val = MIN( min_val, max_val );
  89. pl.max_val = MAX( min_val, max_val );
  90. pl.step = MAX( log_step, 1. );
  91. return pl;
  92. }
  93. CV_INLINE CvParamLattice cvDefaultParamLattice( void )
  94. {
  95. CvParamLattice pl = {0,0,0};
  96. return pl;
  97. }
  98. #endif
  99. /* Variable type */
  100. #define CV_VAR_NUMERICAL 0
  101. #define CV_VAR_ORDERED 0
  102. #define CV_VAR_CATEGORICAL 1
  103. #define CV_TYPE_NAME_ML_SVM "opencv-ml-svm"
  104. #define CV_TYPE_NAME_ML_KNN "opencv-ml-knn"
  105. #define CV_TYPE_NAME_ML_NBAYES "opencv-ml-bayesian"
  106. #define CV_TYPE_NAME_ML_EM "opencv-ml-em"
  107. #define CV_TYPE_NAME_ML_BOOSTING "opencv-ml-boost-tree"
  108. #define CV_TYPE_NAME_ML_TREE "opencv-ml-tree"
  109. #define CV_TYPE_NAME_ML_ANN_MLP "opencv-ml-ann-mlp"
  110. #define CV_TYPE_NAME_ML_CNN "opencv-ml-cnn"
  111. #define CV_TYPE_NAME_ML_RTREES "opencv-ml-random-trees"
  112. class CV_EXPORTS CvStatModel
  113. {
  114. public:
  115. CvStatModel();
  116. virtual ~CvStatModel();
  117. virtual void clear();
  118. virtual void save( const char* filename, const char* name=0 );
  119. virtual void load( const char* filename, const char* name=0 );
  120. virtual void write( CvFileStorage* storage, const char* name );
  121. virtual void read( CvFileStorage* storage, CvFileNode* node );
  122. protected:
  123. const char* default_model_name;
  124. };
  125. /****************************************************************************************\
  126. * Normal Bayes Classifier *
  127. \****************************************************************************************/
  128. class CV_EXPORTS CvNormalBayesClassifier : public CvStatModel
  129. {
  130. public:
  131. CvNormalBayesClassifier();
  132. virtual ~CvNormalBayesClassifier();
  133. CvNormalBayesClassifier( const CvMat* _train_data, const CvMat* _responses,
  134. const CvMat* _var_idx=0, const CvMat* _sample_idx=0 );
  135. virtual bool train( const CvMat* _train_data, const CvMat* _responses,
  136. const CvMat* _var_idx = 0, const CvMat* _sample_idx=0, bool update=false );
  137. virtual float predict( const CvMat* _samples, CvMat* results=0 ) const;
  138. virtual void clear();
  139. virtual void write( CvFileStorage* storage, const char* name );
  140. virtual void read( CvFileStorage* storage, CvFileNode* node );
  141. protected:
  142. int var_count, var_all;
  143. CvMat* var_idx;
  144. CvMat* cls_labels;
  145. CvMat** count;
  146. CvMat** sum;
  147. CvMat** productsum;
  148. CvMat** avg;
  149. CvMat** inv_eigen_values;
  150. CvMat** cov_rotate_mats;
  151. CvMat* c;
  152. };
  153. /****************************************************************************************\
  154. * K-Nearest Neighbour Classifier *
  155. \****************************************************************************************/
  156. // k Nearest Neighbors
  157. class CV_EXPORTS CvKNearest : public CvStatModel
  158. {
  159. public:
  160. CvKNearest();
  161. virtual ~CvKNearest();
  162. CvKNearest( const CvMat* _train_data, const CvMat* _responses,
  163. const CvMat* _sample_idx=0, bool _is_regression=false, int max_k=32 );
  164. virtual bool train( const CvMat* _train_data, const CvMat* _responses,
  165. const CvMat* _sample_idx=0, bool is_regression=false,
  166. int _max_k=32, bool _update_base=false );
  167. virtual float find_nearest( const CvMat* _samples, int k, CvMat* results=0,
  168. const float** neighbors=0, CvMat* neighbor_responses=0, CvMat* dist=0 ) const;
  169. virtual void clear();
  170. int get_max_k() const;
  171. int get_var_count() const;
  172. int get_sample_count() const;
  173. bool is_regression() const;
  174. protected:
  175. virtual float write_results( int k, int k1, int start, int end,
  176. const float* neighbor_responses, const float* dist, CvMat* _results,
  177. CvMat* _neighbor_responses, CvMat* _dist, Cv32suf* sort_buf ) const;
  178. virtual void find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
  179. float* neighbor_responses, const float** neighbors, float* dist ) const;
  180. int max_k, var_count;
  181. int total;
  182. bool regression;
  183. CvVectors* samples;
  184. };
  185. /****************************************************************************************\
  186. * Support Vector Machines *
  187. \****************************************************************************************/
  188. // SVM training parameters
  189. struct CV_EXPORTS CvSVMParams
  190. {
  191. CvSVMParams();
  192. CvSVMParams( int _svm_type, int _kernel_type,
  193. double _degree, double _gamma, double _coef0,
  194. double _C, double _nu, double _p,
  195. CvMat* _class_weights, CvTermCriteria _term_crit );
  196. int svm_type;
  197. int kernel_type;
  198. double degree; // for poly
  199. double gamma; // for poly/rbf/sigmoid
  200. double coef0; // for poly/sigmoid
  201. double C; // for CV_SVM_C_SVC, CV_SVM_EPS_SVR and CV_SVM_NU_SVR
  202. double nu; // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR
  203. double p; // for CV_SVM_EPS_SVR
  204. CvMat* class_weights; // for CV_SVM_C_SVC
  205. CvTermCriteria term_crit; // termination criteria
  206. };
  207. struct CV_EXPORTS CvSVMKernel
  208. {
  209. typedef void (CvSVMKernel::*Calc)( int vec_count, int vec_size, const float** vecs,
  210. const float* another, float* results );
  211. CvSVMKernel();
  212. CvSVMKernel( const CvSVMParams* _params, Calc _calc_func );
  213. virtual bool create( const CvSVMParams* _params, Calc _calc_func );
  214. virtual ~CvSVMKernel();
  215. virtual void clear();
  216. virtual void calc( int vcount, int n, const float** vecs, const float* another, float* results );
  217. const CvSVMParams* params;
  218. Calc calc_func;
  219. virtual void calc_non_rbf_base( int vec_count, int vec_size, const float** vecs,
  220. const float* another, float* results,
  221. double alpha, double beta );
  222. virtual void calc_linear( int vec_count, int vec_size, const float** vecs,
  223. const float* another, float* results );
  224. virtual void calc_rbf( int vec_count, int vec_size, const float** vecs,
  225. const float* another, float* results );
  226. virtual void calc_poly( int vec_count, int vec_size, const float** vecs,
  227. const float* another, float* results );
  228. virtual void calc_sigmoid( int vec_count, int vec_size, const float** vecs,
  229. const float* another, float* results );
  230. };
  231. struct CvSVMKernelRow
  232. {
  233. CvSVMKernelRow* prev;
  234. CvSVMKernelRow* next;
  235. float* data;
  236. };
  237. struct CvSVMSolutionInfo
  238. {
  239. double obj;
  240. double rho;
  241. double upper_bound_p;
  242. double upper_bound_n;
  243. double r; // for Solver_NU
  244. };
  245. class CV_EXPORTS CvSVMSolver
  246. {
  247. public:
  248. typedef bool (CvSVMSolver::*SelectWorkingSet)( int& i, int& j );
  249. typedef float* (CvSVMSolver::*GetRow)( int i, float* row, float* dst, bool existed );
  250. typedef void (CvSVMSolver::*CalcRho)( double& rho, double& r );
  251. CvSVMSolver();
  252. CvSVMSolver( int count, int var_count, const float** samples, char* y,
  253. int alpha_count, double* alpha, double Cp, double Cn,
  254. CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
  255. SelectWorkingSet select_working_set, CalcRho calc_rho );
  256. virtual bool create( int count, int var_count, const float** samples, char* y,
  257. int alpha_count, double* alpha, double Cp, double Cn,
  258. CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
  259. SelectWorkingSet select_working_set, CalcRho calc_rho );
  260. virtual ~CvSVMSolver();
  261. virtual void clear();
  262. virtual bool solve_generic( CvSVMSolutionInfo& si );
  263. virtual bool solve_c_svc( int count, int var_count, const float** samples, char* y,
  264. double Cp, double Cn, CvMemStorage* storage,
  265. CvSVMKernel* kernel, double* alpha, CvSVMSolutionInfo& si );
  266. virtual bool solve_nu_svc( int count, int var_count, const float** samples, char* y,
  267. CvMemStorage* storage, CvSVMKernel* kernel,
  268. double* alpha, CvSVMSolutionInfo& si );
  269. virtual bool solve_one_class( int count, int var_count, const float** samples,
  270. CvMemStorage* storage, CvSVMKernel* kernel,
  271. double* alpha, CvSVMSolutionInfo& si );
  272. virtual bool solve_eps_svr( int count, int var_count, const float** samples, const float* y,
  273. CvMemStorage* storage, CvSVMKernel* kernel,
  274. double* alpha, CvSVMSolutionInfo& si );
  275. virtual bool solve_nu_svr( int count, int var_count, const float** samples, const float* y,
  276. CvMemStorage* storage, CvSVMKernel* kernel,
  277. double* alpha, CvSVMSolutionInfo& si );
  278. virtual float* get_row_base( int i, bool* _existed );
  279. virtual float* get_row( int i, float* dst );
  280. int sample_count;
  281. int var_count;
  282. int cache_size;
  283. int cache_line_size;
  284. const float** samples;
  285. const CvSVMParams* params;
  286. CvMemStorage* storage;
  287. CvSVMKernelRow lru_list;
  288. CvSVMKernelRow* rows;
  289. int alpha_count;
  290. double* G;
  291. double* alpha;
  292. // -1 - lower bound, 0 - free, 1 - upper bound
  293. char* alpha_status;
  294. char* y;
  295. double* b;
  296. float* buf[2];
  297. double eps;
  298. int max_iter;
  299. double C[2]; // C[0] == Cn, C[1] == Cp
  300. CvSVMKernel* kernel;
  301. SelectWorkingSet select_working_set_func;
  302. CalcRho calc_rho_func;
  303. GetRow get_row_func;
  304. virtual bool select_working_set( int& i, int& j );
  305. virtual bool select_working_set_nu_svm( int& i, int& j );
  306. virtual void calc_rho( double& rho, double& r );
  307. virtual void calc_rho_nu_svm( double& rho, double& r );
  308. virtual float* get_row_svc( int i, float* row, float* dst, bool existed );
  309. virtual float* get_row_one_class( int i, float* row, float* dst, bool existed );
  310. virtual float* get_row_svr( int i, float* row, float* dst, bool existed );
  311. };
  312. struct CvSVMDecisionFunc
  313. {
  314. double rho;
  315. int sv_count;
  316. double* alpha;
  317. int* sv_index;
  318. };
  319. // SVM model
  320. class CV_EXPORTS CvSVM : public CvStatModel
  321. {
  322. public:
  323. // SVM type
  324. enum { C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104 };
  325. // SVM kernel type
  326. enum { LINEAR=0, POLY=1, RBF=2, SIGMOID=3 };
  327. CvSVM();
  328. virtual ~CvSVM();
  329. CvSVM( const CvMat* _train_data, const CvMat* _responses,
  330. const CvMat* _var_idx=0, const CvMat* _sample_idx=0,
  331. CvSVMParams _params=CvSVMParams() );
  332. virtual bool train( const CvMat* _train_data, const CvMat* _responses,
  333. const CvMat* _var_idx=0, const CvMat* _sample_idx=0,
  334. CvSVMParams _params=CvSVMParams() );
  335. virtual float predict( const CvMat* _sample ) const;
  336. virtual int get_support_vector_count() const;
  337. virtual const float* get_support_vector(int i) const;
  338. virtual void clear();
  339. virtual void write( CvFileStorage* storage, const char* name );
  340. virtual void read( CvFileStorage* storage, CvFileNode* node );
  341. int get_var_count() const { return var_idx ? var_idx->cols : var_all; }
  342. protected:
  343. virtual bool set_params( const CvSVMParams& _params );
  344. virtual bool train1( int sample_count, int var_count, const float** samples,
  345. const void* _responses, double Cp, double Cn,
  346. CvMemStorage* _storage, double* alpha, double& rho );
  347. virtual void create_kernel();
  348. virtual void create_solver();
  349. virtual void write_params( CvFileStorage* fs );
  350. virtual void read_params( CvFileStorage* fs, CvFileNode* node );
  351. CvSVMParams params;
  352. CvMat* class_labels;
  353. int var_all;
  354. float** sv;
  355. int sv_total;
  356. CvMat* var_idx;
  357. CvMat* class_weights;
  358. CvSVMDecisionFunc* decision_func;
  359. CvMemStorage* storage;
  360. CvSVMSolver* solver;
  361. CvSVMKernel* kernel;
  362. };
/* The function trains an SVM model with optimal parameters obtained by cross-validation.
The parameters to be estimated should be indicated by setting their values to FLT_MAX.
The optimal parameters are saved in <model_params>. */
  366. /*CVAPI(CvStatModel*)
  367. cvTrainSVM_CrossValidation( const CvMat* train_data, int tflag,
  368. const CvMat* responses,
  369. CvStatModelParams* model_params,
  370. const CvStatModelParams* cross_valid_params,
  371. const CvMat* comp_idx CV_DEFAULT(0),
  372. const CvMat* sample_idx CV_DEFAULT(0),
  373. const CvParamLattice* degree_lattice CV_DEFAULT(0),
  374. const CvParamLattice* gamma_lattice CV_DEFAULT(0),
  375. const CvParamLattice* coef0_lattice CV_DEFAULT(0),
  376. const CvParamLattice* C_lattice CV_DEFAULT(0),
  377. const CvParamLattice* nu_lattice CV_DEFAULT(0),
  378. const CvParamLattice* p_lattice CV_DEFAULT(0) );*/
  379. /****************************************************************************************\
  380. * Expectation - Maximization *
  381. \****************************************************************************************/
  382. struct CV_EXPORTS CvEMParams
  383. {
  384. CvEMParams() : nclusters(10), cov_mat_type(1/*CvEM::COV_MAT_DIAGONAL*/),
  385. start_step(0/*CvEM::START_AUTO_STEP*/), probs(0), weights(0), means(0), covs(0)
  386. {
  387. term_crit=cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON );
  388. }
  389. CvEMParams( int _nclusters, int _cov_mat_type=1/*CvEM::COV_MAT_DIAGONAL*/,
  390. int _start_step=0/*CvEM::START_AUTO_STEP*/,
  391. CvTermCriteria _term_crit=cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON),
  392. const CvMat* _probs=0, const CvMat* _weights=0, const CvMat* _means=0, const CvMat** _covs=0 ) :
  393. nclusters(_nclusters), cov_mat_type(_cov_mat_type), start_step(_start_step),
  394. probs(_probs), weights(_weights), means(_means), covs(_covs), term_crit(_term_crit)
  395. {}
  396. int nclusters;
  397. int cov_mat_type;
  398. int start_step;
  399. const CvMat* probs;
  400. const CvMat* weights;
  401. const CvMat* means;
  402. const CvMat** covs;
  403. CvTermCriteria term_crit;
  404. };
  405. class CV_EXPORTS CvEM : public CvStatModel
  406. {
  407. public:
  408. // Type of covariation matrices
  409. enum { COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2 };
  410. // The initial step
  411. enum { START_E_STEP=1, START_M_STEP=2, START_AUTO_STEP=0 };
  412. CvEM();
  413. CvEM( const CvMat* samples, const CvMat* sample_idx=0,
  414. CvEMParams params=CvEMParams(), CvMat* labels=0 );
  415. virtual ~CvEM();
  416. virtual bool train( const CvMat* samples, const CvMat* sample_idx=0,
  417. CvEMParams params=CvEMParams(), CvMat* labels=0 );
  418. virtual float predict( const CvMat* sample, CvMat* probs ) const;
  419. virtual void clear();
  420. int get_nclusters() const;
  421. const CvMat* get_means() const;
  422. const CvMat** get_covs() const;
  423. const CvMat* get_weights() const;
  424. const CvMat* get_probs() const;
  425. protected:
  426. virtual void set_params( const CvEMParams& params,
  427. const CvVectors& train_data );
  428. virtual void init_em( const CvVectors& train_data );
  429. virtual double run_em( const CvVectors& train_data );
  430. virtual void init_auto( const CvVectors& samples );
  431. virtual void kmeans( const CvVectors& train_data, int nclusters,
  432. CvMat* labels, CvTermCriteria criteria,
  433. const CvMat* means );
  434. CvEMParams params;
  435. double log_likelihood;
  436. CvMat* means;
  437. CvMat** covs;
  438. CvMat* weights;
  439. CvMat* probs;
  440. CvMat* log_weight_div_det;
  441. CvMat* inv_eigen_values;
  442. CvMat** cov_rotate_mats;
  443. };
  444. /****************************************************************************************\
  445. * Decision Tree *
  446. \****************************************************************************************/
  447. struct CvPair32s32f
  448. {
  449. int i;
  450. float val;
  451. };
  452. #define CV_DTREE_CAT_DIR(idx,subset) \
  453. (2*((subset[(idx)>>5]&(1 << ((idx) & 31)))==0)-1)
  454. struct CvDTreeSplit
  455. {
  456. int var_idx;
  457. int inversed;
  458. float quality;
  459. CvDTreeSplit* next;
  460. union
  461. {
  462. int subset[2];
  463. struct
  464. {
  465. float c;
  466. int split_point;
  467. }
  468. ord;
  469. };
  470. };
  471. struct CvDTreeNode
  472. {
  473. int class_idx;
  474. int Tn;
  475. double value;
  476. CvDTreeNode* parent;
  477. CvDTreeNode* left;
  478. CvDTreeNode* right;
  479. CvDTreeSplit* split;
  480. int sample_count;
  481. int depth;
  482. int* num_valid;
  483. int offset;
  484. int buf_idx;
  485. double maxlr;
  486. // global pruning data
  487. int complexity;
  488. double alpha;
  489. double node_risk, tree_risk, tree_error;
  490. // cross-validation pruning data
  491. int* cv_Tn;
  492. double* cv_node_risk;
  493. double* cv_node_error;
  494. int get_num_valid(int vi) { return num_valid ? num_valid[vi] : sample_count; }
  495. void set_num_valid(int vi, int n) { if( num_valid ) num_valid[vi] = n; }
  496. };
  497. struct CV_EXPORTS CvDTreeParams
  498. {
  499. int max_categories;
  500. int max_depth;
  501. int min_sample_count;
  502. int cv_folds;
  503. bool use_surrogates;
  504. bool use_1se_rule;
  505. bool truncate_pruned_tree;
  506. float regression_accuracy;
  507. const float* priors;
  508. CvDTreeParams() : max_categories(10), max_depth(INT_MAX), min_sample_count(10),
  509. cv_folds(10), use_surrogates(true), use_1se_rule(true),
  510. truncate_pruned_tree(true), regression_accuracy(0.01f), priors(0)
  511. {}
  512. CvDTreeParams( int _max_depth, int _min_sample_count,
  513. float _regression_accuracy, bool _use_surrogates,
  514. int _max_categories, int _cv_folds,
  515. bool _use_1se_rule, bool _truncate_pruned_tree,
  516. const float* _priors ) :
  517. max_categories(_max_categories), max_depth(_max_depth),
  518. min_sample_count(_min_sample_count), cv_folds (_cv_folds),
  519. use_surrogates(_use_surrogates), use_1se_rule(_use_1se_rule),
  520. truncate_pruned_tree(_truncate_pruned_tree),
  521. regression_accuracy(_regression_accuracy),
  522. priors(_priors)
  523. {}
  524. };
  525. struct CV_EXPORTS CvDTreeTrainData
  526. {
  527. CvDTreeTrainData();
  528. CvDTreeTrainData( const CvMat* _train_data, int _tflag,
  529. const CvMat* _responses, const CvMat* _var_idx=0,
  530. const CvMat* _sample_idx=0, const CvMat* _var_type=0,
  531. const CvMat* _missing_mask=0,
  532. const CvDTreeParams& _params=CvDTreeParams(),
  533. bool _shared=false, bool _add_labels=false );
  534. virtual ~CvDTreeTrainData();
  535. virtual void set_data( const CvMat* _train_data, int _tflag,
  536. const CvMat* _responses, const CvMat* _var_idx=0,
  537. const CvMat* _sample_idx=0, const CvMat* _var_type=0,
  538. const CvMat* _missing_mask=0,
  539. const CvDTreeParams& _params=CvDTreeParams(),
  540. bool _shared=false, bool _add_labels=false,
  541. bool _update_data=false );
  542. virtual void get_vectors( const CvMat* _subsample_idx,
  543. float* values, uchar* missing, float* responses, bool get_class_idx=false );
  544. virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
  545. virtual void write_params( CvFileStorage* fs );
  546. virtual void read_params( CvFileStorage* fs, CvFileNode* node );
  547. // release all the data
  548. virtual void clear();
  549. int get_num_classes() const;
  550. int get_var_type(int vi) const;
  551. int get_work_var_count() const;
  552. virtual int* get_class_labels( CvDTreeNode* n );
  553. virtual float* get_ord_responses( CvDTreeNode* n );
  554. virtual int* get_labels( CvDTreeNode* n );
  555. virtual int* get_cat_var_data( CvDTreeNode* n, int vi );
  556. virtual CvPair32s32f* get_ord_var_data( CvDTreeNode* n, int vi );
  557. virtual int get_child_buf_idx( CvDTreeNode* n );
  558. ////////////////////////////////////
  559. virtual bool set_params( const CvDTreeParams& params );
  560. virtual CvDTreeNode* new_node( CvDTreeNode* parent, int count,
  561. int storage_idx, int offset );
  562. virtual CvDTreeSplit* new_split_ord( int vi, float cmp_val,
  563. int split_point, int inversed, float quality );
  564. virtual CvDTreeSplit* new_split_cat( int vi, float quality );
  565. virtual void free_node_data( CvDTreeNode* node );
  566. virtual void free_train_data();
  567. virtual void free_node( CvDTreeNode* node );
  568. int sample_count, var_all, var_count, max_c_count;
  569. int ord_var_count, cat_var_count;
  570. bool have_labels, have_priors;
  571. bool is_classifier;
  572. int buf_count, buf_size;
  573. bool shared;
  574. CvMat* cat_count;
  575. CvMat* cat_ofs;
  576. CvMat* cat_map;
  577. CvMat* counts;
  578. CvMat* buf;
  579. CvMat* direction;
  580. CvMat* split_buf;
  581. CvMat* var_idx;
  582. CvMat* var_type; // i-th element =
  583. // k<0 - ordered
  584. // k>=0 - categorical, see k-th element of cat_* arrays
  585. CvMat* priors;
  586. CvMat* priors_mult;
  587. CvDTreeParams params;
  588. CvMemStorage* tree_storage;
  589. CvMemStorage* temp_storage;
  590. CvDTreeNode* data_root;
  591. CvSet* node_heap;
  592. CvSet* split_heap;
  593. CvSet* cv_heap;
  594. CvSet* nv_heap;
  595. CvRNG rng;
  596. };
/* Single decision tree model, derived from CvStatModel.
   It is also the base class of CvForestTree and CvBoostTree further below,
   which override some of the protected virtual hooks.
   NOTE: the order of the virtual functions and data members defines the
   vtable/object layout of this exported class; do not reorder them. */
class CV_EXPORTS CvDTree : public CvStatModel
{
public:
    CvDTree();
    virtual ~CvDTree();

    // Trains from raw matrices. All optional inputs default to 0 and the
    // tree parameters default to CvDTreeParams().
    virtual bool train( const CvMat* _train_data, int _tflag,
                        const CvMat* _responses, const CvMat* _var_idx=0,
                        const CvMat* _sample_idx=0, const CvMat* _var_type=0,
                        const CvMat* _missing_mask=0,
                        CvDTreeParams params=CvDTreeParams() );
    // Trains from an already prepared CvDTreeTrainData object, restricted to
    // the samples selected by _subsample_idx.
    virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx );

    // Returns a CvDTreeNode* rather than a scalar response (presumably the
    // node the sample ends up in -- confirm in the implementation).
    virtual CvDTreeNode* predict( const CvMat* _sample, const CvMat* _missing_data_mask=0,
                                  bool preprocessed_input=false ) const;
    virtual const CvMat* get_var_importance();
    virtual void clear();

    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void write( CvFileStorage* fs, const char* name );

    // special read & write methods for trees in the tree ensembles
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvDTreeTrainData* data );
    virtual void write( CvFileStorage* fs );

    const CvDTreeNode* get_root() const;
    int get_pruned_tree_idx() const;
    CvDTreeTrainData* get_data();

protected:
    // Training driver and per-node steps.
    virtual bool do_train( const CvMat* _subsample_idx );
    virtual void try_split_node( CvDTreeNode* n );
    virtual void split_node_data( CvDTreeNode* n );

    // Split search. Judging by the names, the suffixes select ordered vs.
    // categorical variables (ord/cat) and classification vs. regression
    // responses (class/reg); vi is the variable index.
    virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
    virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi );
    virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi );
    virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi );
    virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi );
    virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi );
    virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi );

    virtual double calc_node_dir( CvDTreeNode* node );
    virtual void complete_node_dir( CvDTreeNode* node );
    virtual void cluster_categories( const int* vectors, int vector_count,
                                     int var_count, int* sums, int k, int* cluster_labels );
    virtual void calc_node_value( CvDTreeNode* node );

    // Pruning helpers (prune_cv presumably = cross-validation pruning; T and
    // fold presumably index the pruning sequence and CV fold -- confirm).
    virtual void prune_cv();
    virtual double update_tree_rnc( int T, int fold );
    virtual int cut_tree( int T, int fold, double min_alpha );
    virtual void free_prune_data(bool cut_tree);
    virtual void free_tree();

    // Serialization of individual nodes/splits and of the whole node set.
    virtual void write_node( CvFileStorage* fs, CvDTreeNode* node );
    virtual void write_split( CvFileStorage* fs, CvDTreeSplit* split );
    virtual CvDTreeNode* read_node( CvFileStorage* fs, CvFileNode* node, CvDTreeNode* parent );
    virtual CvDTreeSplit* read_split( CvFileStorage* fs, CvFileNode* node );
    virtual void write_tree_nodes( CvFileStorage* fs );
    virtual void read_tree_nodes( CvFileStorage* fs, CvFileNode* node );

    CvDTreeNode* root;
    int pruned_tree_idx;
    CvMat* var_importance;
    CvDTreeTrainData* data;
};
  653. /****************************************************************************************\
  654. * Random Trees Classifier *
  655. \****************************************************************************************/
class CvRTrees;

/* Decision tree used as a member of a CvRTrees forest. It overrides
   CvDTree::find_best_split() and keeps a pointer to its forest.
   NOTE: the order of the virtual functions and data members defines the
   vtable/object layout of this exported class; do not reorder them. */
class CV_EXPORTS CvForestTree: public CvDTree
{
public:
    CvForestTree();
    virtual ~CvForestTree();

    // Trains on shared train data restricted to _subsample_idx; the forest
    // argument is presumably stored in the `forest` member.
    virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx, CvRTrees* forest );

    // Number of variables of the attached train data, or 0 if data is NULL.
    virtual int get_var_count() const {return data ? data->var_count : 0;}

    virtual void read( CvFileStorage* fs, CvFileNode* node, CvRTrees* forest, CvDTreeTrainData* _data );

    /* dummy methods to avoid warnings: BEGIN */
    // Declaring train()/read() overloads above would hide the inherited
    // CvDTree overloads of the same names (C++ name hiding); re-declaring
    // them here keeps them visible and silences overloaded-virtual warnings.
    virtual bool train( const CvMat* _train_data, int _tflag,
                        const CvMat* _responses, const CvMat* _var_idx=0,
                        const CvMat* _sample_idx=0, const CvMat* _var_type=0,
                        const CvMat* _missing_mask=0,
                        CvDTreeParams params=CvDTreeParams() );
    virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx );
    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvDTreeTrainData* data );
    /* dummy methods to avoid warnings: END */

protected:
    virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
    CvRTrees* forest;
};
  680. struct CV_EXPORTS CvRTParams : public CvDTreeParams
  681. {
  682. //Parameters for the forest
  683. bool calc_var_importance; // true <=> RF processes variable importance
  684. int nactive_vars;
  685. CvTermCriteria term_crit;
  686. CvRTParams() : CvDTreeParams( 5, 10, 0, false, 10, 0, false, false, 0 ),
  687. calc_var_importance(false), nactive_vars(0)
  688. {
  689. term_crit = cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 50, 0.1 );
  690. }
  691. CvRTParams( int _max_depth, int _min_sample_count,
  692. float _regression_accuracy, bool _use_surrogates,
  693. int _max_categories, const float* _priors, bool _calc_var_importance,
  694. int _nactive_vars, int max_num_of_trees_in_the_forest,
  695. float forest_accuracy, int termcrit_type ) :
  696. CvDTreeParams( _max_depth, _min_sample_count, _regression_accuracy,
  697. _use_surrogates, _max_categories, 0,
  698. false, false, _priors ),
  699. calc_var_importance(_calc_var_importance),
  700. nactive_vars(_nactive_vars)
  701. {
  702. term_crit = cvTermCriteria(termcrit_type,
  703. max_num_of_trees_in_the_forest, forest_accuracy);
  704. }
  705. };
/* Random trees (random forest) model: an ensemble of CvForestTree objects.
   NOTE: the order of the virtual functions and data members defines the
   vtable/object layout of this exported class; do not reorder them. */
class CV_EXPORTS CvRTrees : public CvStatModel
{
public:
    CvRTrees();
    virtual ~CvRTrees();

    // Forest settings, including the tree-count limit in term_crit, come in
    // params, which defaults to CvRTParams().
    virtual bool train( const CvMat* _train_data, int _tflag,
                        const CvMat* _responses, const CvMat* _var_idx=0,
                        const CvMat* _sample_idx=0, const CvMat* _var_type=0,
                        const CvMat* _missing_mask=0,
                        CvRTParams params=CvRTParams() );
    virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const;
    virtual void clear();

    virtual const CvMat* get_var_importance();
    // Proximity between two samples; the missing-value masks are optional.
    virtual float get_proximity( const CvMat* sample1, const CvMat* sample2,
                                 const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const;

    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void write( CvFileStorage* fs, const char* name );

    CvMat* get_active_var_mask();
    CvRNG* get_rng();

    int get_tree_count() const;
    CvForestTree* get_tree(int i) const;

protected:
    bool grow_forest( const CvTermCriteria term_crit );

    // array of the trees of the forest
    CvForestTree** trees;
    CvDTreeTrainData* data;
    int ntrees;          // presumably the number of entries in `trees`
    int nclasses;
    double oob_error;    // presumably the out-of-bag error estimate
    CvMat* var_importance;
    int nsamples;
    CvRNG rng;
    CvMat* active_var_mask;
};
  740. /****************************************************************************************\
  741. * Boosted tree classifier *
  742. \****************************************************************************************/
/* Training parameters for CvBoost: the per-tree CvDTreeParams plus the
   boosting settings below. */
struct CV_EXPORTS CvBoostParams : public CvDTreeParams
{
    int boost_type;          // presumably one of CvBoost::DISCRETE/REAL/LOGIT/GENTLE
    int weak_count;          // presumably the number of weak trees to build
    int split_criteria;      // presumably one of CvBoost::DEFAULT/GINI/MISCLASS/SQERR
    double weight_trim_rate;

    CvBoostParams();
    // The constructor parameter names mirror the members (and CvDTreeParams
    // fields) they presumably initialize.
    CvBoostParams( int boost_type, int weak_count, double weight_trim_rate,
                   int max_depth, bool use_surrogates, const float* priors );
};
class CvBoost;

/* Decision tree used as a weak learner of CvBoost. It overrides the split
   search, node value and node direction hooks of CvDTree and keeps a
   pointer to its ensemble.
   NOTE: the order of the virtual functions and data members defines the
   vtable/object layout of this exported class; do not reorder them. */
class CV_EXPORTS CvBoostTree: public CvDTree
{
public:
    CvBoostTree();
    virtual ~CvBoostTree();

    virtual bool train( CvDTreeTrainData* _train_data,
                        const CvMat* subsample_idx, CvBoost* ensemble );
    // Scales the tree by s (presumably its node values -- confirm).
    virtual void scale( double s );
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvBoost* ensemble, CvDTreeTrainData* _data );
    virtual void clear();

    /* dummy methods to avoid warnings: BEGIN */
    // Declaring train()/read() overloads above would hide the inherited
    // CvDTree overloads of the same names (C++ name hiding); re-declaring
    // them here keeps them visible and silences overloaded-virtual warnings.
    virtual bool train( const CvMat* _train_data, int _tflag,
                        const CvMat* _responses, const CvMat* _var_idx=0,
                        const CvMat* _sample_idx=0, const CvMat* _var_type=0,
                        const CvMat* _missing_mask=0,
                        CvDTreeParams params=CvDTreeParams() );
    virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx );
    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvDTreeTrainData* data );
    /* dummy methods to avoid warnings: END */

protected:
    virtual void try_split_node( CvDTreeNode* n );
    virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi );
    virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi );
    virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi );
    virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi );
    virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi );
    virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi );
    virtual void calc_node_value( CvDTreeNode* n );
    virtual double calc_node_dir( CvDTreeNode* n );

    CvBoost* ensemble;   // owning boosting model (presumably set by train()/read())
};
/* Boosted ensemble of CvBoostTree weak learners.
   NOTE: the order of the virtual functions and data members defines the
   vtable/object layout of this exported class; do not reorder them. */
class CV_EXPORTS CvBoost : public CvStatModel
{
public:
    // Boosting type
    enum { DISCRETE=0, REAL=1, LOGIT=2, GENTLE=3 };

    // Splitting criteria (note: the value 2 is not used)
    enum { DEFAULT=0, GINI=1, MISCLASS=3, SQERR=4 };

    CvBoost();
    virtual ~CvBoost();

    // Takes the same arguments as train(), minus `update`.
    CvBoost( const CvMat* _train_data, int _tflag,
             const CvMat* _responses, const CvMat* _var_idx=0,
             const CvMat* _sample_idx=0, const CvMat* _var_type=0,
             const CvMat* _missing_mask=0,
             CvBoostParams params=CvBoostParams() );

    // update defaults to false; presumably true extends the existing
    // ensemble instead of retraining from scratch -- confirm.
    virtual bool train( const CvMat* _train_data, int _tflag,
                        const CvMat* _responses, const CvMat* _var_idx=0,
                        const CvMat* _sample_idx=0, const CvMat* _var_type=0,
                        const CvMat* _missing_mask=0,
                        CvBoostParams params=CvBoostParams(),
                        bool update=false );

    // weak_responses, slice and raw_mode are optional; slice defaults to
    // CV_WHOLE_SEQ (presumably: use every weak predictor).
    virtual float predict( const CvMat* _sample, const CvMat* _missing=0,
                           CvMat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ,
                           bool raw_mode=false ) const;
    virtual void prune( CvSlice slice );

    virtual void clear();

    virtual void write( CvFileStorage* storage, const char* name );
    virtual void read( CvFileStorage* storage, CvFileNode* node );

    CvSeq* get_weak_predictors();

    CvMat* get_weights();
    CvMat* get_subtree_weights();
    CvMat* get_weak_response();
    const CvBoostParams& get_params() const;

protected:
    virtual bool set_params( const CvBoostParams& _params );
    virtual void update_weights( CvBoostTree* tree );
    virtual void trim_weights();
    virtual void write_params( CvFileStorage* fs );
    virtual void read_params( CvFileStorage* fs, CvFileNode* node );

    CvDTreeTrainData* data;
    CvBoostParams params;
    CvSeq* weak;              // presumably the weak trees returned by get_weak_predictors()

    CvMat* orig_response;
    CvMat* sum_response;
    CvMat* weak_eval;
    CvMat* subsample_mask;
    CvMat* weights;
    CvMat* subtree_weights;
    bool have_subsample;
};
  837. /****************************************************************************************\
  838. * Artificial Neural Networks (ANN) *
  839. \****************************************************************************************/
  840. /////////////////////////////////// Multi-Layer Perceptrons //////////////////////////////
/* Training parameters for CvANN_MLP::train(). */
struct CV_EXPORTS CvANN_MLP_TrainParams
{
    CvANN_MLP_TrainParams();
    // param1/param2 presumably map onto the backpropagation or RPROP fields
    // below depending on train_method -- confirm against the implementation.
    CvANN_MLP_TrainParams( CvTermCriteria term_crit, int train_method,
                           double param1, double param2=0 );
    ~CvANN_MLP_TrainParams();

    // Training methods (presumably the values of train_method).
    enum { BACKPROP=0, RPROP=1 };

    CvTermCriteria term_crit;
    int train_method;

    // backpropagation parameters
    double bp_dw_scale, bp_moment_scale;

    // rprop parameters
    double rp_dw0, rp_dw_plus, rp_dw_minus, rp_dw_min, rp_dw_max;
};
  855. class CV_EXPORTS CvANN_MLP : public CvStatModel
  856. {
  857. public:
  858. CvANN_MLP();
  859. CvANN_MLP( const CvMat* _layer_sizes,
  860. int _activ_func=SIGMOID_SYM,
  861. double _f_param1=0, double _f_param2=0 );
  862. virtual ~CvANN_MLP();
  863. virtual void create( const CvMat* _layer_sizes,
  864. int _activ_func=SIGMOID_SYM,
  865. double _f_param1=0, double _f_param2=0 );
  866. virtual int train( const CvMat* _inputs, const CvMat* _outputs,
  867. const CvMat* _sample_weights, const CvMat* _sample_idx=0,
  868. CvANN_MLP_TrainParams _params = CvANN_MLP_TrainParams(),
  869. int flags=0 );
  870. virtual float predict( const CvMat* _inputs,
  871. CvMat* _outputs ) const;
  872. virtual void clear();
  873. // possible activation functions
  874. enum { IDENTITY = 0, SIGMOID_SYM = 1, GAUSSIAN = 2 };
  875. // available training flags
  876. enum { UPDATE_WEIGHTS = 1, NO_INPUT_SCALE = 2, NO_OUTPUT_SCALE = 4 };
  877. virtual void read( CvFileStorage* fs, CvFileNode* node );
  878. virtual void write( CvFileStorage* storage, const char* name );
  879. int get_layer_count() { return layer_sizes ? layer_sizes->cols : 0; }
  880. const CvMat* get_layer_sizes() { return layer_sizes; }
  881. double* get_weights(int layer)
  882. {
  883. return layer_sizes && weights &&
  884. (unsigned)layer <= (unsigned)layer_sizes->cols ? weights[layer] : 0;
  885. }
  886. protected:
  887. virtual bool prepare_to_train( const CvMat* _inputs, const CvMat* _outputs,
  888. const CvMat* _sample_weights, const CvMat* _sample_idx,
  889. CvVectors* _ivecs, CvVectors* _ovecs, double** _sw, int _flags );
  890. // sequential random backpropagation
  891. virtual int train_backprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
  892. // RPROP algorithm
  893. virtual int train_rprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
  894. virtual void calc_activ_func( CvMat* xf, const double* bias ) const;
  895. virtual void calc_activ_func_deriv( CvMat* xf, CvMat* deriv, const double* bias ) const;
  896. virtual void set_activ_func( int _activ_func=SIGMOID_SYM,
  897. double _f_param1=0, double _f_param2=0 );
  898. virtual void init_weights();
  899. virtual void scale_input( const CvMat* _src, CvMat* _dst ) const;
  900. virtual void scale_output( const CvMat* _src, CvMat* _dst ) const;
  901. virtual void calc_input_scale( const CvVectors* vecs, int flags );
  902. virtual void calc_output_scale( const CvVectors* vecs, int flags );
  903. virtual void write_params( CvFileStorage* fs );
  904. virtual void read_params( CvFileStorage* fs, CvFileNode* node );
  905. CvMat* layer_sizes;
  906. CvMat* wbuf;
  907. CvMat* sample_weights;
  908. double** weights;
  909. double f_param1, f_param2;
  910. double min_val, max_val, min_val1, max_val1;
  911. int activ_func;
  912. int max_count, max_buf_sz;
  913. CvANN_MLP_TrainParams params;
  914. CvRNG rng;
  915. };
// NOTE: everything from here to the matching #endif is excluded by the
// preprocessor; the CNN and estimation declarations in it are never compiled.
#if 0
/****************************************************************************************\
*                              Convolutional Neural Network                              *
\****************************************************************************************/
typedef struct CvCNNLayer CvCNNLayer;
typedef struct CvCNNetwork CvCNNetwork;

// Learning-rate decrease schedules; presumably the values accepted by the
// learn_rate_decrease_type argument/field.
#define CV_CNN_LEARN_RATE_DECREASE_HYPERBOLICALLY 1
#define CV_CNN_LEARN_RATE_DECREASE_SQRT_INV 2
#define CV_CNN_LEARN_RATE_DECREASE_LOG_INV 3

// Gradient estimation modes; presumably the values of grad_estim_type in
// CvCNNStatModelParams.
#define CV_CNN_GRAD_ESTIM_RANDOM 0
#define CV_CNN_GRAD_ESTIM_BY_WORST_IMG 1

// Layer tags stored in CvCNNLayer::flags: the CV_MAGIC_MASK bits must equal
// ICV_CNN_LAYER and the remaining bits hold the concrete layer kind.
#define ICV_CNN_LAYER 0x55550000
#define ICV_CNN_CONVOLUTION_LAYER 0x00001111
#define ICV_CNN_SUBSAMPLING_LAYER 0x00002222
#define ICV_CNN_FULLCONNECT_LAYER 0x00003333

// Type checks: the layer is non-NULL and its magic bits equal ICV_CNN_LAYER;
// the specific checks also compare the non-magic bits with the layer kind.
#define ICV_IS_CNN_LAYER( layer ) \
( ((layer) != NULL) && ((((CvCNNLayer*)(layer))->flags & CV_MAGIC_MASK)\
== ICV_CNN_LAYER ))
#define ICV_IS_CNN_CONVOLUTION_LAYER( layer ) \
( (ICV_IS_CNN_LAYER( layer )) && (((CvCNNLayer*) (layer))->flags \
& ~CV_MAGIC_MASK) == ICV_CNN_CONVOLUTION_LAYER )
#define ICV_IS_CNN_SUBSAMPLING_LAYER( layer ) \
( (ICV_IS_CNN_LAYER( layer )) && (((CvCNNLayer*) (layer))->flags \
& ~CV_MAGIC_MASK) == ICV_CNN_SUBSAMPLING_LAYER )
#define ICV_IS_CNN_FULLCONNECT_LAYER( layer ) \
( (ICV_IS_CNN_LAYER( layer )) && (((CvCNNLayer*) (layer))->flags \
& ~CV_MAGIC_MASK) == ICV_CNN_FULLCONNECT_LAYER )

// Callback types for the forward/backward/release members of
// CV_CNN_LAYER_FIELDS() and for CvCNNetwork's add_layer/release members.
typedef void (CV_CDECL *CvCNNLayerForward)
    ( CvCNNLayer* layer, const CvMat* input, CvMat* output );
typedef void (CV_CDECL *CvCNNLayerBackward)
    ( CvCNNLayer* layer, int t, const CvMat* X, const CvMat* dE_dY, CvMat* dE_dX );
typedef void (CV_CDECL *CvCNNLayerRelease)
    (CvCNNLayer** layer);
typedef void (CV_CDECL *CvCNNetworkAddLayer)
    (CvCNNetwork* network, CvCNNLayer* layer);
typedef void (CV_CDECL *CvCNNetworkRelease)
    (CvCNNetwork** network);
// Leading fields shared by every CNN layer struct; expanded at the top of
// CvCNNLayer and of each concrete layer type so they share a common prefix.
#define CV_CNN_LAYER_FIELDS() \
/* Indicator of the layer's type */ \
int flags; \
\
/* Number of input images */ \
int n_input_planes; \
/* Height of each input image */ \
int input_height; \
/* Width of each input image */ \
int input_width; \
\
/* Number of output images */ \
int n_output_planes; \
/* Height of each output image */ \
int output_height; \
/* Width of each output image */ \
int output_width; \
\
/* Learning rate at the first iteration */ \
float init_learn_rate; \
/* Dynamics of learning rate decreasing */ \
int learn_rate_decrease_type; \
/* Trainable weights of the layer (including bias) */ \
/* i-th row is a set of weights of the i-th output plane */ \
CvMat* weights; \
\
CvCNNLayerForward forward; \
CvCNNLayerBackward backward; \
CvCNNLayerRelease release; \
/* Pointers to the previous and next layers in the network */ \
CvCNNLayer* prev_layer; \
CvCNNLayer* next_layer

// Generic layer: only the common fields.
typedef struct CvCNNLayer
{
    CV_CNN_LAYER_FIELDS();
}CvCNNLayer;

// Convolution layer.
typedef struct CvCNNConvolutionLayer
{
    CV_CNN_LAYER_FIELDS();
    // Kernel size (height and width) for convolution.
    int K;
    // Connection matrix: the (i,j)-th element is 1 iff there is a connection
    // between the i-th plane of the current layer and the j-th plane of the
    // previous layer, and 0 otherwise.
    CvMat *connect_mask;
    // NOTE(review): there is no dedicated learning-rate field here; the
    // learning rate at the first iteration is init_learn_rate from
    // CV_CNN_LAYER_FIELDS().
}CvCNNConvolutionLayer;

// Subsampling layer.
typedef struct CvCNNSubSamplingLayer
{
    CV_CNN_LAYER_FIELDS();
    // ratio between the heights (or widths - ratios are supposed to be equal)
    // of the input and output planes
    int sub_samp_scale;
    // amplitude of sigmoid activation function
    float a;
    // scale parameter of sigmoid activation function
    float s;
    // exp2ssumWX = exp(2<s>*(bias+w*(x1+...+x4))), where x1,...x4 are some elements of X;
    // used when computing the activation function in the backward pass
    CvMat* exp2ssumWX;
    // (x1+x2+x3+x4), where x1,...x4 are some elements of X;
    // used when computing the activation function in the backward pass
    CvMat* sumX;
}CvCNNSubSamplingLayer;

// Fully connected layer (the original authors describe it as the structure
// of the last layer).
typedef struct CvCNNFullConnectLayer
{
    CV_CNN_LAYER_FIELDS();
    // amplitude of sigmoid activation function
    float a;
    // scale parameter of sigmoid activation function
    float s;
    // exp2ssumWX = exp(2*<s>*(W*X)) - the vector used in computing the
    // activation function and its derivative by the formulae
    // activ.func. = <a>(exp(2<s>WX)-1)/(exp(2<s>WX)+1) == <a> - 2<a>/(<exp2ssumWX> + 1)
    // (activ.func.)' = 4<a><s>exp(2<s>WX)/(exp(2<s>WX)+1)^2
    CvMat* exp2ssumWX;
}CvCNNFullConnectLayer;

// Network: layer count, layer pointer and add/release callbacks.
typedef struct CvCNNetwork
{
    int n_layers;
    CvCNNLayer* layers;   // presumably the first layer, linked through prev_layer/next_layer
    CvCNNetworkAddLayer add_layer;
    CvCNNetworkRelease release;
}CvCNNetwork;
// Trained CNN classifier model.
typedef struct CvCNNStatModel
{
    CV_STAT_MODEL_FIELDS();
    CvCNNetwork* network;
    // etalons are allocated as rows, the i-th etalon has label cls_labels[i]
    CvMat* etalons;
    // classes labels
    CvMat* cls_labels;
}CvCNNStatModel;

// Training parameters for the CNN classifier.
typedef struct CvCNNStatModelParams
{
    CV_STAT_MODEL_PARAM_FIELDS();
    // network must be created by the functions cvCreateCNNetwork and <add_layer>
    CvCNNetwork* network;
    CvMat* etalons;
    // termination criteria
    int max_iter;
    int start_iter;
    int grad_estim_type;
}CvCNNStatModelParams;

// Layer constructors; the optional matrices default to 0.
CVAPI(CvCNNLayer*) cvCreateCNNConvolutionLayer(
    int n_input_planes, int input_height, int input_width,
    int n_output_planes, int K,
    float init_learn_rate, int learn_rate_decrease_type,
    CvMat* connect_mask CV_DEFAULT(0), CvMat* weights CV_DEFAULT(0) );

CVAPI(CvCNNLayer*) cvCreateCNNSubSamplingLayer(
    int n_input_planes, int input_height, int input_width,
    int sub_samp_scale, float a, float s,
    float init_learn_rate, int learn_rate_decrease_type, CvMat* weights CV_DEFAULT(0) );

CVAPI(CvCNNLayer*) cvCreateCNNFullConnectLayer(
    int n_inputs, int n_outputs, float a, float s,
    float init_learn_rate, int learning_type, CvMat* weights CV_DEFAULT(0) );

CVAPI(CvCNNetwork*) cvCreateCNNetwork( CvCNNLayer* first_layer );

// Same parameter list as the createClassifier callback of cvCrossValidation;
// three of the CvMat* parameters are unnamed here.
CVAPI(CvStatModel*) cvTrainCNNClassifier(
    const CvMat* train_data, int tflag,
    const CvMat* responses,
    const CvStatModelParams* params,
    const CvMat* CV_DEFAULT(0),
    const CvMat* sample_idx CV_DEFAULT(0),
    const CvMat* CV_DEFAULT(0), const CvMat* CV_DEFAULT(0) );
  1073. const CvMat* responses,
  1074. const CvStatModelParams* params,
  1075. const CvMat* CV_DEFAULT(0),
  1076. const CvMat* sample_idx CV_DEFAULT(0),
  1077. const CvMat* CV_DEFAULT(0), const CvMat* CV_DEFAULT(0) );
  1078. /****************************************************************************************\
  1079. * Estimate classifiers algorithms *
  1080. \****************************************************************************************/
  1081. typedef const CvMat* (CV_CDECL *CvStatModelEstimateGetMat)
  1082. ( const CvStatModel* estimateModel );
  1083. typedef int (CV_CDECL *CvStatModelEstimateNextStep)
  1084. ( CvStatModel* estimateModel );
  1085. typedef void (CV_CDECL *CvStatModelEstimateCheckClassifier)
  1086. ( CvStatModel* estimateModel,
  1087. const CvStatModel* model,
  1088. const CvMat* features,
  1089. int sample_t_flag,
  1090. const CvMat* responses );
  1091. typedef void (CV_CDECL *CvStatModelEstimateCheckClassifierEasy)
  1092. ( CvStatModel* estimateModel,
  1093. const CvStatModel* model );
  1094. typedef float (CV_CDECL *CvStatModelEstimateGetCurrentResult)
  1095. ( const CvStatModel* estimateModel,
  1096. float* correlation );
  1097. typedef void (CV_CDECL *CvStatModelEstimateReset)
  1098. ( CvStatModel* estimateModel );
//-------------------------------- Cross-validation --------------------------------------
// Parameter fields of a cross-validation estimate (k_fold, regression flag,
// RNG), appended after the common stat-model parameter fields.
#define CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_PARAM_FIELDS() \
CV_STAT_MODEL_PARAM_FIELDS(); \
int k_fold; \
int is_regression; \
CvRNG* rng

typedef struct CvCrossValidationParams
{
    CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_PARAM_FIELDS();
} CvCrossValidationParams;

// State of a cross-validation estimate: the callbacks declared above, fold
// bookkeeping, and accumulators for the results (counts, squared error and
// the sums presumably used for a correlation estimate).
#define CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_FIELDS() \
CvStatModelEstimateGetMat getTrainIdxMat; \
CvStatModelEstimateGetMat getCheckIdxMat; \
CvStatModelEstimateNextStep nextStep; \
CvStatModelEstimateCheckClassifier check; \
CvStatModelEstimateGetCurrentResult getResult; \
CvStatModelEstimateReset reset; \
int is_regression; \
int folds_all; \
int samples_all; \
int* sampleIdxAll; \
int* folds; \
int max_fold_size; \
int current_fold; \
int is_checked; \
CvMat* sampleIdxTrain; \
CvMat* sampleIdxEval; \
CvMat* predict_results; \
int correct_results; \
int all_results; \
double sq_error; \
double sum_correct; \
double sum_predict; \
double sum_cc; \
double sum_pp; \
double sum_cp

typedef struct CvCrossValidationModel
{
    CV_STAT_MODEL_FIELDS();
    CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_FIELDS();
} CvCrossValidationModel;
// Creates a cross-validation estimate model over samples_all samples; the
// estimate parameters and sample index are optional.
CVAPI(CvStatModel*)
cvCreateCrossValidationEstimateModel
           ( int samples_all,
             const CvStatModelParams* estimateParams CV_DEFAULT(0),
             const CvMat* sampleIdx CV_DEFAULT(0) );

// Cross-validates a classifier produced by createClassifier on trueData /
// trueClasses and returns a float score (whether it is an accuracy or an
// error is not visible from this declaration). If pCrValModel is given it
// presumably receives the estimate model -- confirm.
CVAPI(float)
cvCrossValidation( const CvMat* trueData,
                   int tflag,
                   const CvMat* trueClasses,
                   CvStatModel* (*createClassifier)( const CvMat*,
                                                     int,
                                                     const CvMat*,
                                                     const CvStatModelParams*,
                                                     const CvMat*,
                                                     const CvMat*,
                                                     const CvMat*,
                                                     const CvMat* ),
                   const CvStatModelParams* estimateParams CV_DEFAULT(0),
                   const CvStatModelParams* trainParams CV_DEFAULT(0),
                   const CvMat* compIdx CV_DEFAULT(0),
                   const CvMat* sampleIdx CV_DEFAULT(0),
                   CvStatModel** pCrValModel CV_DEFAULT(0),
                   const CvMat* typeMask CV_DEFAULT(0),
                   const CvMat* missedMeasurementMask CV_DEFAULT(0) );
#endif /* 0 */
  1165. /****************************************************************************************\
*                             Auxiliary function declarations                            *
  1167. \****************************************************************************************/
  1168. /* Generates <sample> from multivariate normal distribution, where <mean> - is an
  1169. average row vector, <cov> - symmetric covariation matrix */
  1170. CVAPI(void) cvRandMVNormal( CvMat* mean, CvMat* cov, CvMat* sample,
  1171. CvRNG* rng CV_DEFAULT(0) );
  1172. /* Generates sample from gaussian mixture distribution */
  1173. CVAPI(void) cvRandGaussMixture( CvMat* means[],
  1174. CvMat* covs[],
  1175. float weights[],
  1176. int clsnum,
  1177. CvMat* sample,
  1178. CvMat* sampClasses CV_DEFAULT(0) );
  1179. #define CV_TS_CONCENTRIC_SPHERES 0
  1180. /* creates test set */
  1181. CVAPI(void) cvCreateTestSet( int type, CvMat** samples,
  1182. int num_samples,
  1183. int num_features,
  1184. CvMat** responses,
  1185. int num_classes, ... );
  1186. /* Aij <- Aji for i > j if lower_to_upper != 0
  1187. for i < j if lower_to_upper = 0 */
  1188. CVAPI(void) cvCompleteSymm( CvMat* matrix, int lower_to_upper );
  1189. #ifdef __cplusplus
  1190. }
  1191. #endif
  1192. #endif /*__ML_H__*/
  1193. /* End of file. */