// ClassificationModel.cs
  1. #if !UNITY_WSA_10_0
  2. using OpenCVForUnity.CoreModule;
  3. using OpenCVForUnity.UtilsModule;
  4. using System;
  5. using System.Collections.Generic;
  6. using System.Runtime.InteropServices;
  7. namespace OpenCVForUnity.DnnModule
  8. {
  9. // C++: class ClassificationModel
  10. /**
  11. * This class represents high-level API for classification models.
  12. *
  13. * ClassificationModel allows to set params for preprocessing input image.
  14. * ClassificationModel creates net from file with trained weights and config,
  15. * sets preprocessing input, runs forward pass and return top-1 prediction.
  16. */
  17. public class ClassificationModel : Model
  18. {
  19. protected override void Dispose(bool disposing)
  20. {
  21. try
  22. {
  23. if (disposing)
  24. {
  25. }
  26. if (IsEnabledDispose)
  27. {
  28. if (nativeObj != IntPtr.Zero)
  29. dnn_ClassificationModel_delete(nativeObj);
  30. nativeObj = IntPtr.Zero;
  31. }
  32. }
  33. finally
  34. {
  35. base.Dispose(disposing);
  36. }
  37. }
  38. protected internal ClassificationModel(IntPtr addr) : base(addr) { }
  39. // internal usage only
  40. public static new ClassificationModel __fromPtr__(IntPtr addr) { return new ClassificationModel(addr); }
  41. //
  42. // C++: cv::dnn::ClassificationModel::ClassificationModel(String model, String config = "")
  43. //
  44. /**
  45. * Create classification model from network represented in one of the supported formats.
  46. * An order of {code model} and {code config} arguments does not matter.
  47. * param model Binary file contains trained weights.
  48. * param config Text file contains network configuration.
  49. */
  50. public ClassificationModel(string model, string config) :
  51. base(DisposableObject.ThrowIfNullIntPtr(dnn_ClassificationModel_ClassificationModel_10(model, config)))
  52. {
  53. }
  54. /**
  55. * Create classification model from network represented in one of the supported formats.
  56. * An order of {code model} and {code config} arguments does not matter.
  57. * param model Binary file contains trained weights.
  58. */
  59. public ClassificationModel(string model) :
  60. base(DisposableObject.ThrowIfNullIntPtr(dnn_ClassificationModel_ClassificationModel_11(model)))
  61. {
  62. }
  63. //
  64. // C++: cv::dnn::ClassificationModel::ClassificationModel(Net network)
  65. //
  66. /**
  67. * Create model from deep learning network.
  68. * param network Net object.
  69. */
  70. public ClassificationModel(Net network) :
  71. base(DisposableObject.ThrowIfNullIntPtr(dnn_ClassificationModel_ClassificationModel_12(network.nativeObj)))
  72. {
  73. }
  74. //
  75. // C++: ClassificationModel cv::dnn::ClassificationModel::setEnableSoftmaxPostProcessing(bool enable)
  76. //
  77. /**
  78. * Set enable/disable softmax post processing option.
  79. *
  80. * If this option is true, softmax is applied after forward inference within the classify() function
  81. * to convert the confidences range to [0.0-1.0].
  82. * This function allows you to toggle this behavior.
  83. * Please turn true when not contain softmax layer in model.
  84. * param enable Set enable softmax post processing within the classify() function.
  85. * return automatically generated
  86. */
  87. public ClassificationModel setEnableSoftmaxPostProcessing(bool enable)
  88. {
  89. ThrowIfDisposed();
  90. return new ClassificationModel(DisposableObject.ThrowIfNullIntPtr(dnn_ClassificationModel_setEnableSoftmaxPostProcessing_10(nativeObj, enable)));
  91. }
  92. //
  93. // C++: bool cv::dnn::ClassificationModel::getEnableSoftmaxPostProcessing()
  94. //
  95. /**
  96. * Get enable/disable softmax post processing option.
  97. *
  98. * This option defaults to false, softmax post processing is not applied within the classify() function.
  99. * return automatically generated
  100. */
  101. public bool getEnableSoftmaxPostProcessing()
  102. {
  103. ThrowIfDisposed();
  104. return dnn_ClassificationModel_getEnableSoftmaxPostProcessing_10(nativeObj);
  105. }
  106. //
  107. // C++: void cv::dnn::ClassificationModel::classify(Mat frame, int& classId, float& conf)
  108. //
  109. public void classify(Mat frame, int[] classId, float[] conf)
  110. {
  111. ThrowIfDisposed();
  112. if (frame != null) frame.ThrowIfDisposed();
  113. double[] classId_out = new double[1];
  114. double[] conf_out = new double[1];
  115. dnn_ClassificationModel_classify_10(nativeObj, frame.nativeObj, classId_out, conf_out);
  116. if (classId != null) classId[0] = (int)classId_out[0];
  117. if (conf != null) conf[0] = (float)conf_out[0];
  118. }
  119. #if (UNITY_IOS || UNITY_WEBGL) && !UNITY_EDITOR
  120. const string LIBNAME = "__Internal";
  121. #else
  122. const string LIBNAME = "opencvforunity";
  123. #endif
  124. // C++: cv::dnn::ClassificationModel::ClassificationModel(String model, String config = "")
  125. [DllImport(LIBNAME)]
  126. private static extern IntPtr dnn_ClassificationModel_ClassificationModel_10(string model, string config);
  127. [DllImport(LIBNAME)]
  128. private static extern IntPtr dnn_ClassificationModel_ClassificationModel_11(string model);
  129. // C++: cv::dnn::ClassificationModel::ClassificationModel(Net network)
  130. [DllImport(LIBNAME)]
  131. private static extern IntPtr dnn_ClassificationModel_ClassificationModel_12(IntPtr network_nativeObj);
  132. // C++: ClassificationModel cv::dnn::ClassificationModel::setEnableSoftmaxPostProcessing(bool enable)
  133. [DllImport(LIBNAME)]
  134. private static extern IntPtr dnn_ClassificationModel_setEnableSoftmaxPostProcessing_10(IntPtr nativeObj, [MarshalAs(UnmanagedType.U1)] bool enable);
  135. // C++: bool cv::dnn::ClassificationModel::getEnableSoftmaxPostProcessing()
  136. [DllImport(LIBNAME)]
  137. [return: MarshalAs(UnmanagedType.U1)]
  138. private static extern bool dnn_ClassificationModel_getEnableSoftmaxPostProcessing_10(IntPtr nativeObj);
  139. // C++: void cv::dnn::ClassificationModel::classify(Mat frame, int& classId, float& conf)
  140. [DllImport(LIBNAME)]
  141. private static extern void dnn_ClassificationModel_classify_10(IntPtr nativeObj, IntPtr frame_nativeObj, double[] classId_out, double[] conf_out);
  142. // native support for java finalize()
  143. [DllImport(LIBNAME)]
  144. private static extern void dnn_ClassificationModel_delete(IntPtr nativeObj);
  145. }
  146. }
  147. #endif