onlineBoosting.hpp 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. /*M///////////////////////////////////////////////////////////////////////////////////////
  2. //
  3. // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
  4. //
  5. // By downloading, copying, installing or using the software you agree to this license.
  6. // If you do not agree to this license, do not download, install,
  7. // copy or use the software.
  8. //
  9. //
  10. // License Agreement
  11. // For Open Source Computer Vision Library
  12. //
  13. // Copyright (C) 2013, OpenCV Foundation, all rights reserved.
  14. // Third party copyrights are property of their respective owners.
  15. //
  16. // Redistribution and use in source and binary forms, with or without modification,
  17. // are permitted provided that the following conditions are met:
  18. //
  19. // * Redistribution's of source code must retain the above copyright notice,
  20. // this list of conditions and the following disclaimer.
  21. //
  22. // * Redistribution's in binary form must reproduce the above copyright notice,
  23. // this list of conditions and the following disclaimer in the documentation
  24. // and/or other materials provided with the distribution.
  25. //
  26. // * The name of the copyright holders may not be used to endorse or promote products
  27. // derived from this software without specific prior written permission.
  28. //
  29. // This software is provided by the copyright holders and contributors "as is" and
  30. // any express or implied warranties, including, but not limited to, the implied
  31. // warranties of merchantability and fitness for a particular purpose are disclaimed.
  32. // In no event shall the Intel Corporation or contributors be liable for any direct,
  33. // indirect, incidental, special, exemplary, or consequential damages
  34. // (including, but not limited to, procurement of substitute goods or services;
  35. // loss of use, data, or profits; or business interruption) however caused
  36. // and on any theory of liability, whether in contract, strict liability,
  37. // or tort (including negligence or otherwise) arising in any way out of
  38. // the use of this software, even if advised of the possibility of such damage.
  39. //
  40. //M*/
  41. #ifndef __OPENCV_ONLINEBOOSTING_HPP__
  42. #define __OPENCV_ONLINEBOOSTING_HPP__
  43. #include "opencv2/core.hpp"
  44. namespace cv
  45. {
  46. //! @addtogroup tracking
  47. //! @{
  // TODO: based on the original implementation from
  // http://vision.ucsd.edu/~bbabenko/project_miltrack.shtml

  // Forward declarations: the classes declared below refer to one another
  // through pointers, so every name is introduced before first use.
  class BaseClassifier;
  class ClassifierThreshold;
  class Detector;
  class EstimatedGaussDistribution;
  class WeakClassifierHaarFeature;
  55. class StrongClassifierDirectSelection
  56. {
  57. public:
  58. StrongClassifierDirectSelection( int numBaseClf, int numWeakClf, Size patchSz, const Rect& sampleROI, bool useFeatureEx = false, int iterationInit =
  59. 0 );
  60. virtual ~StrongClassifierDirectSelection();
  61. void initBaseClassifier();
  62. bool update( const Mat& image, int target, float importance = 1.0 );
  63. float eval( const Mat& response );
  64. std::vector<int> getSelectedWeakClassifier();
  65. float classifySmooth( const std::vector<Mat>& images, const Rect& sampleROI, int& idx );
  66. int getNumBaseClassifier();
  67. Size getPatchSize() const;
  68. Rect getROI() const;
  69. bool getUseFeatureExchange() const;
  70. int getReplacedClassifier() const;
  71. void replaceWeakClassifier( int idx );
  72. int getSwappedClassifier() const;
  73. private:
  74. //StrongClassifier
  75. int numBaseClassifier;
  76. int numAllWeakClassifier;
  77. int numWeakClassifier;
  78. int iterInit;
  79. BaseClassifier** baseClassifier;
  80. std::vector<float> alpha;
  81. cv::Size patchSize;
  82. bool useFeatureExchange;
  83. //StrongClassifierDirectSelection
  84. std::vector<bool> m_errorMask;
  85. std::vector<float> m_errors;
  86. std::vector<float> m_sumErrors;
  87. Detector* detector;
  88. Rect ROI;
  89. int replacedClassifier;
  90. int swappedClassifier;
  91. };
  92. class BaseClassifier
  93. {
  94. public:
  95. BaseClassifier( int numWeakClassifier, int iterationInit );
  96. BaseClassifier( int numWeakClassifier, int iterationInit, WeakClassifierHaarFeature** weakCls );
  97. WeakClassifierHaarFeature** getReferenceWeakClassifier()
  98. {
  99. return weakClassifier;
  100. }
  101. ;
  102. void trainClassifier( const Mat& image, int target, float importance, std::vector<bool>& errorMask );
  103. int selectBestClassifier( std::vector<bool>& errorMask, float importance, std::vector<float> & errors );
  104. int computeReplaceWeakestClassifier( const std::vector<float> & errors );
  105. void replaceClassifierStatistic( int sourceIndex, int targetIndex );
  106. int getIdxOfNewWeakClassifier()
  107. {
  108. return m_idxOfNewWeakClassifier;
  109. }
  110. ;
  111. int eval( const Mat& image );
  112. virtual ~BaseClassifier();
  113. float getError( int curWeakClassifier );
  114. void getErrors( float* errors );
  115. int getSelectedClassifier() const;
  116. void replaceWeakClassifier( int index );
  117. protected:
  118. void generateRandomClassifier();
  119. WeakClassifierHaarFeature** weakClassifier;
  120. bool m_referenceWeakClassifier;
  121. int m_numWeakClassifier;
  122. int m_selectedClassifier;
  123. int m_idxOfNewWeakClassifier;
  124. std::vector<float> m_wCorrect;
  125. std::vector<float> m_wWrong;
  126. int m_iterationInit;
  127. };
  /** @brief Estimate of a one-dimensional Gaussian, described by a mean and a sigma.

  Besides the current `m_mean` / `m_sigma`, the object stores `m_P_*` and `m_R_*`
  values for each of them. These are presumably initialised from the
  `P_mean`, `R_mean`, `P_sigma`, `R_sigma` constructor arguments (TODO confirm
  against the .cpp). The class holds only plain floats, so the implicit copy
  operations are safe.
  */
  class EstimatedGaussDistribution
  {
  public:
      EstimatedGaussDistribution();
      EstimatedGaussDistribution( float P_mean, float R_mean, float P_sigma, float R_sigma );
      virtual ~EstimatedGaussDistribution();

      // Read access to the current estimate.
      float getMean();
      float getSigma();

      // Overwrite the current estimate directly.
      void setValues( float mean, float sigma );

      // Incorporate one observed value (presumably a recursive update of mean/sigma).
      // The original declaration carried a commented-out
      // `float timeConstant = -1.0` second parameter.
      void update( float value );

  private:
      float m_mean;
      float m_sigma;
      float m_P_mean;
      float m_P_sigma;
      float m_R_mean;
      float m_R_sigma;
  };
  146. class WeakClassifierHaarFeature
  147. {
  148. public:
  149. WeakClassifierHaarFeature();
  150. virtual ~WeakClassifierHaarFeature();
  151. bool update( float value, int target );
  152. int eval( float value );
  153. private:
  154. float sigma;
  155. float mean;
  156. ClassifierThreshold* m_classifier;
  157. void getInitialDistribution( EstimatedGaussDistribution *distribution );
  158. void generateRandomClassifier( EstimatedGaussDistribution* m_posSamples, EstimatedGaussDistribution* m_negSamples );
  159. };
  160. class Detector
  161. {
  162. public:
  163. Detector( StrongClassifierDirectSelection* classifier );
  164. virtual
  165. ~Detector( void );
  166. void
  167. classifySmooth( const std::vector<Mat>& image, float minMargin = 0 );
  168. int
  169. getNumDetections();
  170. float
  171. getConfidence( int patchIdx );
  172. float
  173. getConfidenceOfDetection( int detectionIdx );
  174. float getConfidenceOfBestDetection()
  175. {
  176. return m_maxConfidence;
  177. }
  178. ;
  179. int
  180. getPatchIdxOfBestDetection();
  181. int
  182. getPatchIdxOfDetection( int detectionIdx );
  183. const std::vector<int> &
  184. getIdxDetections() const
  185. {
  186. return m_idxDetections;
  187. }
  188. ;
  189. const std::vector<float> &
  190. getConfidences() const
  191. {
  192. return m_confidences;
  193. }
  194. ;
  195. const cv::Mat &
  196. getConfImageDisplay() const
  197. {
  198. return m_confImageDisplay;
  199. }
  200. private:
  201. void
  202. prepareConfidencesMemory( int numPatches );
  203. void
  204. prepareDetectionsMemory( int numDetections );
  205. StrongClassifierDirectSelection* m_classifier;
  206. std::vector<float> m_confidences;
  207. int m_sizeConfidences;
  208. int m_numDetections;
  209. std::vector<int> m_idxDetections;
  210. int m_sizeDetections;
  211. int m_idxBestDetection;
  212. float m_maxConfidence;
  213. cv::Mat_<float> m_confMatrix;
  214. cv::Mat_<float> m_confMatrixSmooth;
  215. cv::Mat_<unsigned char> m_confImageDisplay;
  216. };
  217. class ClassifierThreshold
  218. {
  219. public:
  220. ClassifierThreshold( EstimatedGaussDistribution* posSamples, EstimatedGaussDistribution* negSamples );
  221. virtual ~ClassifierThreshold();
  222. void update( float value, int target );
  223. int eval( float value );
  224. void* getDistribution( int target );
  225. private:
  226. EstimatedGaussDistribution* m_posSamples;
  227. EstimatedGaussDistribution* m_negSamples;
  228. float m_threshold;
  229. int m_parity;
  230. };
  231. //! @}
  232. } /* namespace cv */
  233. #endif