onlineBoosting.hpp 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. /*M///////////////////////////////////////////////////////////////////////////////////////
  2. //
  3. // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
  4. //
  5. // By downloading, copying, installing or using the software you agree to this license.
  6. // If you do not agree to this license, do not download, install,
  7. // copy or use the software.
  8. //
  9. //
  10. // License Agreement
  11. // For Open Source Computer Vision Library
  12. //
  13. // Copyright (C) 2013, OpenCV Foundation, all rights reserved.
  14. // Third party copyrights are property of their respective owners.
  15. //
  16. // Redistribution and use in source and binary forms, with or without modification,
  17. // are permitted provided that the following conditions are met:
  18. //
  19. // * Redistribution's of source code must retain the above copyright notice,
  20. // this list of conditions and the following disclaimer.
  21. //
  22. // * Redistribution's in binary form must reproduce the above copyright notice,
  23. // this list of conditions and the following disclaimer in the documentation
  24. // and/or other materials provided with the distribution.
  25. //
  26. // * The name of the copyright holders may not be used to endorse or promote products
  27. // derived from this software without specific prior written permission.
  28. //
  29. // This software is provided by the copyright holders and contributors "as is" and
  30. // any express or implied warranties, including, but not limited to, the implied
  31. // warranties of merchantability and fitness for a particular purpose are disclaimed.
  32. // In no event shall the Intel Corporation or contributors be liable for any direct,
  33. // indirect, incidental, special, exemplary, or consequential damages
  34. // (including, but not limited to, procurement of substitute goods or services;
  35. // loss of use, data, or profits; or business interruption) however caused
  36. // and on any theory of liability, whether in contract, strict liability,
  37. // or tort (including negligence or otherwise) arising in any way out of
  38. // the use of this software, even if advised of the possibility of such damage.
  39. //
  40. //M*/
  41. #ifndef __OPENCV_ONLINEBOOSTING_HPP__
  42. #define __OPENCV_ONLINEBOOSTING_HPP__
  43. #include "opencv2/core.hpp"
  44. namespace cv {
  45. namespace detail {
  46. inline namespace tracking {
  47. //! @addtogroup tracking_detail
  48. //! @{
  49. inline namespace online_boosting {
  50. //TODO based on the original implementation
  51. //http://vision.ucsd.edu/~bbabenko/project_miltrack.shtml
  // Forward declarations: StrongClassifierDirectSelection (defined next) refers
  // to BaseClassifier and Detector through pointers, and the later classes refer
  // to WeakClassifierHaarFeature / ClassifierThreshold the same way, before the
  // full definitions appear in this header.
  class BaseClassifier;
  class ClassifierThreshold;
  class Detector;
  class EstimatedGaussDistribution;
  class WeakClassifierHaarFeature;
  57. class StrongClassifierDirectSelection
  58. {
  59. public:
  60. StrongClassifierDirectSelection( int numBaseClf, int numWeakClf, Size patchSz, const Rect& sampleROI, bool useFeatureEx = false, int iterationInit =
  61. 0 );
  62. virtual ~StrongClassifierDirectSelection();
  63. void initBaseClassifier();
  64. bool update( const Mat& image, int target, float importance = 1.0 );
  65. float eval( const Mat& response );
  66. std::vector<int> getSelectedWeakClassifier();
  67. float classifySmooth( const std::vector<Mat>& images, const Rect& sampleROI, int& idx );
  68. int getNumBaseClassifier();
  69. Size getPatchSize() const;
  70. Rect getROI() const;
  71. bool getUseFeatureExchange() const;
  72. int getReplacedClassifier() const;
  73. void replaceWeakClassifier( int idx );
  74. int getSwappedClassifier() const;
  75. private:
  76. //StrongClassifier
  77. int numBaseClassifier;
  78. int numAllWeakClassifier;
  79. int numWeakClassifier;
  80. int iterInit;
  81. BaseClassifier** baseClassifier;
  82. std::vector<float> alpha;
  83. cv::Size patchSize;
  84. bool useFeatureExchange;
  85. //StrongClassifierDirectSelection
  86. std::vector<bool> m_errorMask;
  87. std::vector<float> m_errors;
  88. std::vector<float> m_sumErrors;
  89. Detector* detector;
  90. Rect ROI;
  91. int replacedClassifier;
  92. int swappedClassifier;
  93. };
  94. class BaseClassifier
  95. {
  96. public:
  97. BaseClassifier( int numWeakClassifier, int iterationInit );
  98. BaseClassifier( int numWeakClassifier, int iterationInit, WeakClassifierHaarFeature** weakCls );
  99. WeakClassifierHaarFeature** getReferenceWeakClassifier()
  100. {
  101. return weakClassifier;
  102. }
  103. ;
  104. void trainClassifier( const Mat& image, int target, float importance, std::vector<bool>& errorMask );
  105. int selectBestClassifier( std::vector<bool>& errorMask, float importance, std::vector<float> & errors );
  106. int computeReplaceWeakestClassifier( const std::vector<float> & errors );
  107. void replaceClassifierStatistic( int sourceIndex, int targetIndex );
  108. int getIdxOfNewWeakClassifier()
  109. {
  110. return m_idxOfNewWeakClassifier;
  111. }
  112. ;
  113. int eval( const Mat& image );
  114. virtual ~BaseClassifier();
  115. float getError( int curWeakClassifier );
  116. void getErrors( float* errors );
  117. int getSelectedClassifier() const;
  118. void replaceWeakClassifier( int index );
  119. protected:
  120. void generateRandomClassifier();
  121. WeakClassifierHaarFeature** weakClassifier;
  122. bool m_referenceWeakClassifier;
  123. int m_numWeakClassifier;
  124. int m_selectedClassifier;
  125. int m_idxOfNewWeakClassifier;
  126. std::vector<float> m_wCorrect;
  127. std::vector<float> m_wWrong;
  128. int m_iterationInit;
  129. };
  /**
   * @brief Estimate of a one-dimensional Gaussian, exposed as a mean and sigma.
   *
   * The current estimate lives in @c m_mean / @c m_sigma. Four further
   * parameters (@c m_P_mean, @c m_P_sigma, @c m_R_mean, @c m_R_sigma) presumably
   * hold the values passed to the four-argument constructor. The update rule
   * itself is defined out of line.
   */
  class EstimatedGaussDistribution
  {
  public:
    //! Default constructor; initial values are set in the out-of-line definition.
    EstimatedGaussDistribution();
    //! Presumably stores its arguments in the matching m_P_* / m_R_* members.
    EstimatedGaussDistribution( float P_mean, float R_mean, float P_sigma, float R_sigma );
    virtual ~EstimatedGaussDistribution();

    //! Feeds one sample @p value into the estimate.
    //! (A second `float timeConstant = -1.0` parameter was commented out here.)
    void update( float value );

    float getMean();
    float getSigma();
    //! Presumably assigns @p mean / @p sigma to the current estimate directly.
    void setValues( float mean, float sigma );

  private:
    float m_mean;
    float m_sigma;
    float m_P_mean;
    float m_P_sigma;
    float m_R_mean;
    float m_R_sigma;
  };
  148. class WeakClassifierHaarFeature
  149. {
  150. public:
  151. WeakClassifierHaarFeature();
  152. virtual ~WeakClassifierHaarFeature();
  153. bool update( float value, int target );
  154. int eval( float value );
  155. private:
  156. float sigma;
  157. float mean;
  158. ClassifierThreshold* m_classifier;
  159. void getInitialDistribution( EstimatedGaussDistribution *distribution );
  160. void generateRandomClassifier( EstimatedGaussDistribution* m_posSamples, EstimatedGaussDistribution* m_negSamples );
  161. };
  162. class Detector
  163. {
  164. public:
  165. Detector( StrongClassifierDirectSelection* classifier );
  166. virtual
  167. ~Detector( void );
  168. void
  169. classifySmooth( const std::vector<Mat>& image, float minMargin = 0 );
  170. int
  171. getNumDetections();
  172. float
  173. getConfidence( int patchIdx );
  174. float
  175. getConfidenceOfDetection( int detectionIdx );
  176. float getConfidenceOfBestDetection()
  177. {
  178. return m_maxConfidence;
  179. }
  180. ;
  181. int
  182. getPatchIdxOfBestDetection();
  183. int
  184. getPatchIdxOfDetection( int detectionIdx );
  185. const std::vector<int> &
  186. getIdxDetections() const
  187. {
  188. return m_idxDetections;
  189. }
  190. ;
  191. const std::vector<float> &
  192. getConfidences() const
  193. {
  194. return m_confidences;
  195. }
  196. ;
  197. const cv::Mat &
  198. getConfImageDisplay() const
  199. {
  200. return m_confImageDisplay;
  201. }
  202. private:
  203. void
  204. prepareConfidencesMemory( int numPatches );
  205. void
  206. prepareDetectionsMemory( int numDetections );
  207. StrongClassifierDirectSelection* m_classifier;
  208. std::vector<float> m_confidences;
  209. int m_sizeConfidences;
  210. int m_numDetections;
  211. std::vector<int> m_idxDetections;
  212. int m_sizeDetections;
  213. int m_idxBestDetection;
  214. float m_maxConfidence;
  215. cv::Mat_<float> m_confMatrix;
  216. cv::Mat_<float> m_confMatrixSmooth;
  217. cv::Mat_<unsigned char> m_confImageDisplay;
  218. };
  219. class ClassifierThreshold
  220. {
  221. public:
  222. ClassifierThreshold( EstimatedGaussDistribution* posSamples, EstimatedGaussDistribution* negSamples );
  223. virtual ~ClassifierThreshold();
  224. void update( float value, int target );
  225. int eval( float value );
  226. void* getDistribution( int target );
  227. private:
  228. EstimatedGaussDistribution* m_posSamples;
  229. EstimatedGaussDistribution* m_negSamples;
  230. float m_threshold;
  231. int m_parity;
  232. };
  233. } // namespace
  234. //! @}
  235. }}} // namespace
  236. #endif