SuaKITTrainer.h 74 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327
  1. #pragma once
  2. /*
  3. SuaKIT runtime API 2.2.6.3
  4. Copyright (c) 2018 SuaLab, All right reserved.
  5. */
  6. /**
  7. * @file SuaKITTrainer.h
  8. */
  9. #include "SuaKIT.h"
  10. namespace SuaKIT {
  11. namespace API {
  12. /**
  13. * @brief This is a configuration of the ClassificationTrainer.
  14. * @details It is used as classification train setting parameter for ClassificationTrainer.
  15. * @author ⓒSualab SuaKIT Team
  16. */
  17. class SUAKIT_API ClassificationTrainConfig {
  18. public:
  19. /**
  20. * @brief This is a constructor of ClassificationTrainConfig.
  21. * @param numberOfClasses The number of classes to train.
  22. * @param validationRatio Ratio of validation set. If user does not provide validation set, validation set is split from train set by this ratio.
  23. This ratio is calculated as ((validationRatio/100)*100)% of train set.
  24. This argument is meaningless when the user selects and uses a validation set as a separate ClassificationTrainData object.
  25. * @param epochs In training neural network, one epoch means one sweep of the full training set. This is converted to an iteration by the number of train sets.
  26. * @param modelCapacity The base (pre-defined) model capacity for executing ClassificationTrainer. {0, 1, 2, 3}.
  27. - 0 means small model capacity.
  28. - 1 means normal model capacity.
  29. - 2 means large model capacity.
  30. - 3 means extra large model capacity.
  31. * @param dataRatio This is a weight parameter for each class. If you want to give a larger weight to a specific class, you can increase this array's value corresponding index to the location of the specific class.
  32. This value will be applied at batch-sampling-time. This value affects the ratio between each class sample in training batch.
  33. * @param imageHeight Height of image for train. If the learning image size is variable, set it to the largest height value in the training image.
  34. * @param imageWidth Width of images for train. If the learning image size is variable, set it to the largest width value in the training image.
  35. * @param imageChannel The number of channels of images for train.
  36. * @param augmentationConfig Augmentation parameters for train.
  37. * @param inputDataType This parameter is used to configure input type of neural network. Set it to one of {InputDataType::SINGLE, InputDataType::COMPARISON, InputDataType::MULTIMG}.
  38. In case of multi-image classification, we provide only 1 modelCapacity: {NORMAL}. User needs to select modelCapacity parameter NORMAL capacity.
  39. * @param multiImgCount The number of images for classification trainer, depends on inputDataType parameter.
  40. - In case of InputDataType::SINGLE, multiImgCount parameter needs to be initialized with 1.
  41. - In case of InputDataType::COMPARISON, multiImgCount parameter needs to be initialized with 2.
  42. - In case of InputDataType::MULTIMG, multiImgCount parameter needs to be initialized with multiImgCount >= 2.
  43. * @param imageDepth Depth of images for train.
  44. * @param numModelSelection Save the top N (numModelSelection) models based on the best loss.
  45. If the number of generated models is less than this value, only the currently created models are saved.
  46. * @param minimumEpoch The minimum number of epochs to save the model.
  47. * @param patience The number of epochs to wait before early stop if no progress on the validation set. The patience is often set somewhere between 10 and 100 (10 or 20 is more common), but it really depends on your dataset.
  48. - In case of ClassificationTrainer, the learning rate decay is performed once.
  49. */
  50. __stdcall ClassificationTrainConfig(
  51. SuaKIT_Int32 numberOfClasses,
  52. SuaKIT_Int32 validationRatio,
  53. SuaKIT_Int32 epochs,
  54. SuaKIT_Int32 modelCapacity,
  55. const Int32Array& dataRatio,
  56. SuaKIT_Int32 imageHeight,
  57. SuaKIT_Int32 imageWidth,
  58. SuaKIT_Int32 imageChannel,
  59. const AugmentationConfig& augmentationConfig = AugmentationConfig(),
  60. InputDataType inputDataType = InputDataType::SINGLE,
  61. SuaKIT_Int32 multiImgCount = 1,
  62. DepthType imageDepth = DepthType::_8U,
  63. SuaKIT_Int32 numModelSelection = 1,
  64. float minimumEpoch = 0.0f,
  65. SuaKIT_Int64 patience = 9223372036854775807ll
  66. );
  67. __stdcall ~ClassificationTrainConfig();
  68. void* _GetInternal() const { return internal; }
  69. /**
  70. * @brief Methods to set the class name
  71. * @param classIdx Class Index
  72. * @param className String to set as the name of class corresponding to Class Index.
  73. On Windows, it should be mbcs encoding, and
  74. On Linux, this should be utf8 encoding.
  75. * @return Returns Status.
  76. */
  77. Status __stdcall SetClassName(SuaKIT_Int32 classIdx, const char * className);
  78. /**
  79. * @return Returns Status.
  80. */
  81. Status __stdcall GetStatus() const;
  82. private:
  83. ClassificationTrainConfig(const ClassificationTrainConfig& src); //=delete;
  84. ClassificationTrainConfig& operator=(const ClassificationTrainConfig& rhs); //=delete;
  85. void *internal;
  86. Status m_status;
  87. };
  88. /**
  89. * @brief This is a configuration of the PostTrainer.
  90. * @details It is used as post train setting parameter for PostTrainer.
  91. * @author ⓒSualab SuaKIT Team
  92. */
  93. class SUAKIT_API PostTrainConfig {
  94. public:
  95. /**
  96. * @brief This is a constructor of PostTrainConfig.
  97. * @param numberOfClasses The number of classes to train. You should include only one new class.
  98. This value should be the same with (number of old class (trained classification model) + number of new class)
  99. * @param validationRatio Ratio of validation set. If user does not provide validation set, validation set is split from train set by this ratio.
  100. This ratio is calculated as ((validationRatio/100)*100)% of train set.
  101. This argument is meaningless when the user selects and uses a validation set as a separate PostTrainData object.
  102. * @param epochs In training neural network, one epoch means one sweep of the full training set. This is converted to an iteration by the number of train sets.
  103. Internally, several model of Post trainer would be trained referencing this value.
  104. * @param dataRatio This is a weight parameter for each class. If you want to give a larger weight to a specific class, you can increase this array's value corresponding index to the location of the specific class.
  105. This value will be applied at batch-sampling-time. This value affects the ratio between each class sample in training batch.
  106. * @param imageHeight Height of image for train. If the learning image size is variable, set it to the largest height value in the training image.
  107. * @param imageWidth Width of images for train. If the learning image size is variable, set it to the largest width value in the training image.
  108. * @param imageChannel The number of channels of images for train.
  109. * @param augmentationConfig Augmentation parameters for train.
  110. * @param imageDepth Depth of images for train.
  111. * @param minimumEpoch The minimum number of epochs to save the model.
  112. * @param patience The number of epochs to wait before early stop if no progress on the validation set. The patience is often set somewhere between 10 and 100 (10 or 20 is more common), but it really depends on your dataset.
  113. - In case of PostTrainer, the learning rate decay is performed once.
  114. */
  115. __stdcall PostTrainConfig(
  116. SuaKIT_Int32 numberOfClasses,
  117. SuaKIT_Int32 validationRatio,
  118. SuaKIT_Int32 epochs,
  119. const Int32Array& dataRatio,
  120. SuaKIT_Int32 imageHeight,
  121. SuaKIT_Int32 imageWidth,
  122. SuaKIT_Int32 imageChannel,
  123. const AugmentationConfig& augmentationConfig = AugmentationConfig(),
  124. DepthType imageDepth = DepthType::_8U,
  125. float minimumEpoch = 0.0f,
  126. SuaKIT_Int64 patience = 9223372036854775807ll
  127. );
  128. __stdcall ~PostTrainConfig();
  129. void* _GetInternal() const { return internal; }
  130. /**
  131. * @brief Methods to set the class name
  132. * @param classIdx Class Index
  133. * @param className String to set as the name of class corresponding to Class Index.
  134. On Windows, it should be mbcs encoding, and
  135. On Linux, this should be utf8 encoding.
  136. * @return Returns Status.
  137. */
  138. Status __stdcall SetClassName(SuaKIT_Int32 classIdx, const char * className);
  139. /**
  140. * @return Returns Status.
  141. */
  142. Status __stdcall GetStatus() const;
  143. private:
  144. PostTrainConfig(const PostTrainConfig& src); //=delete;
  145. PostTrainConfig& operator=(const PostTrainConfig& rhs); //=delete;
  146. void *internal;
  147. Status m_status;
  148. };
  149. /**
  150. * @brief This is a configuration of the SegmentationTrainer.
  151. * @details It is used as segmentation train setting parameter for SegmentationTrainer.
  152. * @author ⓒSualab SuaKIT Team
  153. */
  154. class SUAKIT_API SegmentationTrainConfig {
  155. public:
  156. /**
  157. * @brief This is a constructor of the SegmentationTrainConfig.
  158. * @param numberOfClasses The number of classes to train. In segmentation training case, this parameter should be calculated except unlabeled class.
  159. * @param validationRatio Ratio of validation set. If user does not provide validation set, validation set is split from train set by this ratio.
  160. This ratio is calculated as ((validationRatio/100)*100)% of train set.
  161. * @param epochs In training neural network, one epoch means one sweep of the full training set. This is converted to an iteration by the number of train sets.
  162. * @param modelCapacity The base (pre-defined) model capacity for executing SegmentationTrainer. {0, 1, 2, 3}.
  163. - 0 means small model capacity.
  164. - 1 means normal model capacity.
  165. - 2 means large model capacity.
  166. - 3 means extra large model capacity.
  167. * @param dataRatio This is a weight parameter for each class. If you want to give a larger weight to specific class, you can increase this array's value corresponding index to the location of the specific class.
  168. This value will be applied at batch-sampling-time. This value affects the ratio between each class sample in training batch.
  169. * @param unlabeledRatio If you also want to adjust ratio of unlabeled patch, you can append unlabeledRatio to original dataRatio by this parameter.
  170. * @param strideRatio Training patch data is sampled from original image by sliding windows manner. strideRatio value means stride of shifting.
  171. If you assign 50 into this parameter, then 50% size corresponding to patch size (patchSize * 0.5) would be shifting scale.
  172. Internally, this value must be in the range [10, 100].
  173. * @param patchSize The size of training patch data. This value should be multiple of 4, and must be in the range [128, max].
  174. You can set arbitrary value satisfying those condition, but WE HIGHLY RECOMMEND 128 or 256 for this parameter.
  175. * @param imageChannel The number of channels of images for train.
  176. * @param augmentationConfig Augmentation parameter for train.
  177. * @param inputDataType This parameter is used to configure input type of neural network. Set it to one of {InputDataType::SINGLE, InputDataType::COMPARISON, InputDataType::MULTIMG}.
  178. In case of multi-image segmentation, we provide only 2 modelCapacity: {NORMAL, LARGE}. User needs to select modelCapacity parameter NORMAL or LARGE capacity.
  179. * @param multiImgCount The number of images for segmentation trainer, depends on inputDataType parameter.
  180. - In case of InputDataType::SINGLE, multiImgCount parameter needs to be initialized with 1.
  181. - In case of InputDataType::COMPARISON, multiImgCount parameter needs to be initialized with 2.
  182. - In case of InputDataType::MULTIMG, multiImgCount parameter needs to be initialized with multiImgCount >= 2.
  183. * @param imageDepth Depth of images for train.
  184. * @param numModelSelection Save the top N (numModelSelection) models based on the best loss.
  185. If the number of generated models is less than this value, only the currently created models are saved.
  186. * @param minimumEpoch The minimum number of epochs to save the model.
  187. * @param patience The number of epochs to wait before early stop if no progress on the validation set. The patience is often set somewhere between 10 and 100 (10 or 20 is more common), but it really depends on your dataset.
  188. - In case of SegmentationTrainer, the learning rate decay is performed twice.
  189. * @param segmentationModelType This parameter is used to configure model type of segmentation. Set it to one of {ModelType::SEGMENTATION, ModelType::SEGMENTATION_SENSITIVE, ModelType::SEGMENTATION_CONTEXTUAL}.
  190. */
  191. __stdcall SegmentationTrainConfig(
  192. SuaKIT_Int32 numberOfClasses,
  193. SuaKIT_Int32 validationRatio,
  194. SuaKIT_Int32 epochs,
  195. SuaKIT_Int32 modelCapacity,
  196. const Int32Array& dataRatio,
  197. SuaKIT_Int32 unlabeledRatio,
  198. SuaKIT_Int32 strideRatio,
  199. SuaKIT_Int32 patchSize,
  200. SuaKIT_Int32 imageChannel,
  201. const AugmentationConfig& augmentationConfig = AugmentationConfig(),
  202. InputDataType inputDataType = InputDataType::SINGLE,
  203. SuaKIT_Int32 multiImgCount = 1,
  204. DepthType imageDepth = DepthType::_8U,
  205. SuaKIT_Int32 numModelSelection = 1,
  206. float minimumEpoch = 0.0f,
  207. SuaKIT_Int64 patience = 9223372036854775807ll,
  208. ModelType segmentationModelType = ModelType::SEGMENTATION
  209. );
  210. __stdcall ~SegmentationTrainConfig();
  211. void* _GetInternal() const { return internal; }
  212. /**
  213. * @brief Methods to set the class name
  214. * @param classIdx Class Index
  215. * @param className String to set as the name of class corresponding to Class Index.
  216. On Windows, it should be mbcs encoding, and
  217. On Linux, this should be utf8 encoding.
  218. * @return Returns Status.
  219. */
  220. Status __stdcall SetClassName(SuaKIT_Int32 classIdx, const char * className);
  221. /**
  222. * @return Returns Status.
  223. */
  224. Status __stdcall GetStatus() const;
  225. private:
  226. SegmentationTrainConfig(const SegmentationTrainConfig& src); //=delete;
  227. SegmentationTrainConfig& operator=(const SegmentationTrainConfig& rhs); //=delete;
  228. void *internal;
  229. Status m_status;
  230. };
  231. /**
  232. * @brief This is a configuration of the OneClassSegmentationTrainer.
  233. * @details It is used as OneClassSegmentation train setting parameter for OneClassSegmentationTrainer.
  234. * @author ⓒSualab SuaKIT Team
  235. */
  236. class SUAKIT_API OneClassSegmentationTrainConfig {
  237. public:
  238. /**
  239. * @brief This is a constructor of the OneClassSegmentationTrainConfig.
  240. * @param validationRatio Ratio of validation set. If user does not provide validation set, validation set is split from train set by this ratio.
  241. This ratio is calculated as ((validationRatio/100)*100)% of train set.
  242. * @param epochs In training neural network, one epoch means one sweep of the full training set. This is converted to an iteration by the number of train sets.
  243. * @param modelCapacity Internal model capacity for executing OneClassSegmentationTrainer. {1, 2}.
  244. - 1 means normal model capacity.
  245. - 2 means large model capacity.
  246. In case of one class segmentation, we provide only 2 modelCapacity: {NORMAL, LARGE}. User needs to select modelCapacity parameter NORMAL or LARGE capacity.
  247. * @param strideRatio Training patch data is sampled from original image by sliding windows manner. strideRatio value means stride of shifting.
  248. If you assign 50 into this parameter, then 50% size corresponding to patch size (patchSize * 0.5) would be shifting scale.
  249. Internally, this value must be in the range [10, 100].
  250. * @param patchSize The size of training patch data. This value should be multiple of 4, and must be in the range [64, max].
  251. You can set arbitrary value satisfying those condition, but WE HIGHLY RECOMMEND 64 or 128 or 256 for this parameter.
  252. * @param imageChannel The number of channels of images for train.
  253. * @param augmentationConfig Augmentation parameter for train.
  254. Currently, augmentation for OneClassSegmentation is not supported. Just put default constructor.
  255. * @param inputDataType This parameter is used to configure input type of neural network. Set it to one of {InputDataType::SINGLE, InputDataType::COMPARISON, InputDataType::MULTIMG}.
  256. Currently, OneClassSegmentation only support SINGLE type.
  257. * @param multiImgCount The number of images for segmentation trainer, depends on inputDataType parameter.
  258. - In case of InputDataType::SINGLE, multiImgCount parameter needs to be initialized with 1.
  259. - In case of InputDataType::COMPARISON, multiImgCount parameter needs to be initialized with 2.
  260. - In case of InputDataType::MULTIMG, multiImgCount parameter needs to be initialized with multiImgCount >= 2.
  261. Currently, OneClassSegmentation only support SINGLE type. (multiImgCount should be 1)
  262. * @param imageDepth Depth of images for train.
  263. * @param numModelSelection Save the top N (numModelSelection) models based on the best loss.
  264. If the number of generated models is less than this value, only the currently created models are saved.
  265. * @param minimumEpoch The minimum number of epochs to save the model.
  266. * @param patience The number of epochs to wait before early stop if no progress on the validation set. The patience is often set somewhere between 10 and 100 (10 or 20 is more common), but it really depends on your dataset.
  267. - In case of OneClassSegmentationTrainer, the learning rate decay is not performed.
  268. */
  269. __stdcall OneClassSegmentationTrainConfig(
  270. SuaKIT_Int32 validationRatio,
  271. SuaKIT_Int32 epochs,
  272. SuaKIT_Int32 modelCapacity,
  273. SuaKIT_Int32 strideRatio,
  274. SuaKIT_Int32 patchSize,
  275. SuaKIT_Int32 imageChannel,
  276. const AugmentationConfig& augmentationConfig = AugmentationConfig(),
  277. InputDataType inputDataType = InputDataType::SINGLE,
  278. SuaKIT_Int32 multiImgCount = 1,
  279. DepthType imageDepth = DepthType::_8U,
  280. SuaKIT_Int32 numModelSelection = 1,
  281. float minimumEpoch = 0.0f,
  282. SuaKIT_Int64 patience = 9223372036854775807ll
  283. );
  284. __stdcall ~OneClassSegmentationTrainConfig();
  285. void* _GetInternal() const { return internal; }
  286. /**
  287. * @return Returns Status.
  288. */
  289. Status __stdcall GetStatus() const;
  290. private:
  291. OneClassSegmentationTrainConfig(const OneClassSegmentationTrainConfig& src); //=delete;
  292. OneClassSegmentationTrainConfig& operator=(const OneClassSegmentationTrainConfig& rhs); //=delete;
  293. void *internal;
  294. Status m_status;
  295. };
  296. /**
  297. * @brief This is a configuration of the DetectionTrainer.
  298. * @details It is used as detection train setting parameter for DetectionTrainer.
  299. * @author ⓒSualab SuaKIT Team
  300. */
  301. class SUAKIT_API DetectionTrainConfig {
  302. public:
  303. /**
  304. * @brief This is a constructor of the DetectionTrainConfig.
  305. * @param numberOfClasses The number of classes to train. In detection training case, this parameter should be calculated except unlabeled class.
  306. * @param validationRatio Ratio of validation set. If user does not provide validation set, validation set is split from train set by this ratio.
  307. This ratio is calculated as ((validationRatio/100)*100)% of train set.
  308. * @param epochs In training neural network, one epoch means one sweep of the full training set. This is converted to an iteration according to the number of train sets.
  309. * @param modelCapacity The base (pre-defined) model capacity for executing DetectionTrainer. {0, 1, 2, 3}.
  310. - 0 means small model capacity.
  311. - 1 means normal model capacity.
  312. - 2 means large model capacity.
  313. - 3 means extra large model capacity.
  314. * @param dataRatio This is a weight parameter for each class. If you want to give a larger weight to specific class, you can increase this array's value corresponding index to the location of the specific class.
  315. * @param imageHeight Height of image for train. If the learning image size is variable, set it to the largest height value in the training image.
  316. * @param imageWidth Width of images for train. If the learning image size is variable, set it to the largest width value in the training image.
  317. * @param imageChannel The number of channels of images for train.
  318. * @param anchors Size of anchor boxes.
  319. Actual shape : (2, N) s.t. N = number of anchor box
  320. if you want to apply 3 anchor boxes, then Actual shape will be = (2, 3) and N = 3
  321. and configuration of array would be { width1, width2, width3, height1, height2, height3 }
  322. Value range : [0, 1] => (normalized to image size)
  323. * @param augmentationConfig Augmentation parameters for train.
  324. * @param inputDataType This parameter is used to configure input type of neural network. Set it to one of {InputDataType::SINGLE, InputDataType::COMPARISON, InputDataType::MULTIMG}.
  325. InputDataType::COMPARISON and InputDataType::MULTIMG are currently not supported. These features will be updated later.
  326. * @param multiImgCount The number of images for detection trainer, depends on inputDataType parameter.
  327. - In case of InputDataType::SINGLE, multiImgCount parameter needs to be initialized with 1.
  328. - In case of InputDataType::COMPARISON, multiImgCount parameter needs to be initialized with 2.
  329. - In case of InputDataType::MULTIMG, multiImgCount parameter needs to be initialized with multiImgCount >= 2.
  330. * @param imageDepth Depth of images for train.
  331. * @param detectionModelType This parameter is used to configure model type of detection. Set it to one of {ModelType::DETECTION_DOLPHIN, ModelType::DETECTION_DRAGON}.
  332. * @param numModelSelection Save the top N (numModelSelection) models based on the best loss.
  333. If the number of generated models is less than this value, only the currently created models are saved.
  334. * @param minimumEpoch The minimum number of epochs to save the model.
  335. * @param patience The number of epochs to wait before early stop if no progress on the validation set. The patience is often set somewhere between 10 and 100 (10 or 20 is more common), but it really depends on your dataset.
  336. - In case of DetectionTrainer, the learning rate decay is not performed.
  337. */
  338. __stdcall DetectionTrainConfig(
  339. SuaKIT_Int32 numberOfClasses,
  340. SuaKIT_Int32 validationRatio,
  341. SuaKIT_Int32 epochs,
  342. SuaKIT_Int32 modelCapacity,
  343. const Int32Array& dataRatio,
  344. SuaKIT_Int32 imageHeight,
  345. SuaKIT_Int32 imageWidth,
  346. SuaKIT_Int32 imageChannel,
  347. const FloatArray& anchors,
  348. const AugmentationConfig& augmentationConfig = AugmentationConfig(),
  349. InputDataType inputDataType = InputDataType::SINGLE,
  350. SuaKIT_Int32 multiImgCount = 1,
  351. DepthType imageDepth = DepthType::_8U,
  352. ModelType detectionModelType = ModelType::DETECTION_DOLPHIN,
  353. SuaKIT_Int32 numModelSelection = 1,
  354. float minimumEpoch = 0.0f,
  355. SuaKIT_Int64 patience = 9223372036854775807ll
  356. );
  357. __stdcall ~DetectionTrainConfig();
  358. void* _GetInternal() const { return internal; }
  359. /**
  360. * @brief Methods to set the class name
  361. * @param classIdx Class Index
  362. * @param className String to set as the name of class corresponding to Class Index.
  363. On Windows, it should be mbcs encoding, and
  364. On Linux, this should be utf8 encoding.
  365. * @return Returns Status.
  366. */
  367. Status SetClassName(SuaKIT_Int32 classIdx, const char * className);
  368. /**
  369. * @return Returns Status.
  370. */
  371. Status __stdcall GetStatus() const;
  372. private:
  373. DetectionTrainConfig(const DetectionTrainConfig& src); //=delete;
  374. DetectionTrainConfig& operator=(const DetectionTrainConfig& rhs); //=delete;
  375. void *internal;
  376. Status m_status;
  377. };
  378. /**
  379. * @brief This class is used to manage data for classification train.
  380. * @details This class contains image and label information for training ClassificationTrainer.
  381. * @author ⓒSualab SuaKIT Team
  382. */
  383. class SUAKIT_API ClassificationTrainData {
  384. public:
  385. /**
  386. * @brief A constructor of ClassificationTrainData.
  387. * @param inputDataType This parameter is used to configure input type of neural network. It initializes {InputDataType::SINGLE, InputDataType::COMPARISON, InputDataType::MULTIMG}.
  388. */
  389. __stdcall ClassificationTrainData(InputDataType inputDataType = SINGLE);
  390. __stdcall ~ClassificationTrainData();
  391. /**
  392. * @brief Push single data for train ClassificationTrainer. It should be used for InputDataType::SINGLE only.
  393. * @param imgPath A path of image to be trained.
  394. * @param classNum The class identifier for the image. If the total number of classes is (N), this parameter should be a value less than N (classNum <N).
  395. * @param maskImagePath A path of mask image. The mask image must be a 1-channel image and needs to have a value 0 or 255, where 0 indicates masked and 255 means not masked.
  396. * @param roi ROI(region of interest) rect.
  397. * @return Returns Status.
  398. */
  399. Status __stdcall PushSingleDataInfo(const wchar_t* imgPath, SuaKIT_Int32 classNum, const wchar_t* maskImagePath = 0, const Rect& roi = Rect());
  400. /**
  401. * @brief Push single data for train ClassificationTrainer. It should be used for InputDataType::SINGLE only.
  402. * @param imgPath A path of image to be trained.
  403. * @param labelImgPath A path of label image. The label image is an 8-bit 1-channel image, with each pixel value representing a class number and an unlabeled area having an 255 pixel value.
  404. * @param classNum The class identifier for the image. If the total number of classes is (N), this parameter should be a value less than N (classNum <N).
  405. * @param maskImagePath A path of mask image. The mask image must be a 1-channel image and needs to have a value 0 or 255, where 0 indicates masked and 255 means not masked.
  406. * @param roi ROI(region of interest) rect.
  407. * @return Returns Status.
  408. */
  409. Status __stdcall PushSingleDataInfo(const wchar_t* imgPath, const wchar_t* labelImgPath, SuaKIT_Int32 classNum, const wchar_t* maskImagePath = 0, const Rect& roi = Rect());
  410. /**
  411. * @brief Push comparison data for train ClassificationTrainer. It should be used for InputDataType::COMPARISON only.
  412. * @param masterImagePath A path of master image to be trained.
  413. * @param slaveImagePath A path of slave image to be trained.
  414. * @param classNum The class identifier for the image. If the total number of classes is (N), this parameter should be a value less than N (classNum <N).
  415. * @param maskImagePath Path of mask image. The mask image must be a 1-channel image and needs to have a value 0 or 255, where 0 indicates masked and 255 means not masked.
  416. * @param roi ROI(region of interest) rect.
  417. * @return Returns Status.
  418. */
  419. Status __stdcall PushComparisonDataInfo(const wchar_t* masterImagePath, const wchar_t* slaveImagePath, SuaKIT_Int32 classNum, const wchar_t* maskImagePath = 0, const Rect& roi = Rect());
  420. /**
  421. * @brief Push multi-image data to train a ClassificationTrainer. It should be used for InputDataType::MULTIMG only.
  422. * @param imagePaths Paths of the images to be trained, given as one string with the paths separated by newline characters ('\n'). For example: "C:\Folder\img1.png\nC:\Folder\img2.png\nC:\Folder\img3.png" (in a C++ literal, each path backslash must be escaped as \\). The order of the images is important.
  423. * @param classNum The class identifier for the image set. If the total number of classes is N, this parameter should be a value less than N (classNum < N).
  424. * @param maskImagePath A path of the mask image; defaults to 0 (no path). The mask image must be a 1-channel image with pixel values 0 or 255, where 0 means masked and 255 means not masked.
  425. * @param roi ROI (region of interest) rect; defaults to Rect().
  426. * @return Returns Status.
  427. */
  428. Status __stdcall PushMultiDataInfo(const wchar_t* imagePaths, SuaKIT_Int32 classNum, const wchar_t* maskImagePath = 0, const Rect& roi = Rect());
  429. Status __stdcall PushSingleDataInfoWithJson(const wchar_t* imgPath, SuaKIT_Int32 classNum, const wchar_t* maskImageJson, const Rect& roi = Rect()); // Like PushSingleDataInfo, but the mask is passed as maskImageJson (presumably JSON text rather than an image path - confirm); the mask argument has no default here.
  430. Status __stdcall PushSingleDataInfoWithJson(const wchar_t* imgPath, const wchar_t* labelJson, SuaKIT_Int32 classNum, const wchar_t* maskImageJson, const Rect& roi = Rect()); // Like the labelImgPath overload of PushSingleDataInfo, but label and mask are passed as labelJson / maskImageJson (presumably JSON text - confirm).
  431. Status __stdcall PushComparisonDataInfoWithJson(const wchar_t* masterImagePath, const wchar_t* slaveImagePath, SuaKIT_Int32 classNum, const wchar_t* maskImageJson, const Rect& roi = Rect()); // Like PushComparisonDataInfo, but the mask is passed as maskImageJson (presumably JSON text - confirm).
  432. Status __stdcall PushMultiDataInfoWithJson(const wchar_t* imagePaths, SuaKIT_Int32 classNum, const wchar_t* maskImageJson, const Rect& roi = Rect()); // Like PushMultiDataInfo, but the mask is passed as maskImageJson (presumably JSON text - confirm).
  433. // Soft-label overloads: a FloatArray softlabel takes the place of the integer classNum (presumably one weight per class - confirm).
  434. Status __stdcall PushSingleDataInfo(const wchar_t* imgPath, const FloatArray& softlabel, const wchar_t* maskImagePath = 0, const Rect& roi = Rect());
  435. Status __stdcall PushComparisonDataInfo(const wchar_t* masterImagePath, const wchar_t* slaveImagePath, const FloatArray& softlabel, const wchar_t* maskImagePath = 0, const Rect& roi = Rect());
  436. Status __stdcall PushMultiDataInfo(const wchar_t* imagePaths, const FloatArray& softlabel, const wchar_t* maskImagePath = 0, const Rect& roi = Rect());
  437. /**
  438. * @brief Get the InputDataType of this ClassificationTrainData.
  439. * @return Returns InputDataType.
  440. */
  441. InputDataType __stdcall GetDataType() const;
  442. /**
  443. * @brief Get the length of this ClassificationTrainData.
  444. * @return Returns the length of image data as SuaKIT_Int64 type.
  445. */
  446. SuaKIT_Int64 __stdcall GetTrainDataLength() const;
  447. void* _getInternal() const { return internal; } // Returns the opaque implementation handle; presumably for SuaKIT-internal use (e.g. by the trainers) - not documented as public API.
  448. /**
  449. * @brief Explicitly releases resources. It is useful in non-native languages (such as C#).
  450. * @return Returns Status.
  451. */
  452. Status __stdcall Destroy();
  453. /**
  454. * @brief Get the current Status of this object. @return Returns Status.
  455. */
  456. Status __stdcall GetStatus() const;
  457. private: // NOTE(review): exported (SUAKIT_API) class - keep data member order and types exactly as in the prebuilt SuaKIT binary.
  458. ClassificationTrainData(const ClassificationTrainData& src); //=delete; private copy constructor (pre-C++11 idiom): clients cannot copy this object.
  459. ClassificationTrainData& operator=(const ClassificationTrainData& rhs); //=delete; private copy assignment (pre-C++11 idiom).
  460. void* internal; // opaque implementation handle, exposed via _getInternal()
  461. InputDataType m_dataType;
  462. Status m_status;
  463. SuaKIT_UInt64 m_flags;
  464. };
  465. /**
  466. * @brief This class is used to manage data for post training (PostTrainer).
  467. * @details This class contains image and label information for training PostTrainer.
  468. * @author ⓒSualab SuaKIT Team
  469. */
  470. class SUAKIT_API PostTrainData {
  471. public:
  472. /**
  473. * @brief A constructor of PostTrainData. Unlike the other train-data classes it takes no InputDataType, because PostTrainer supports InputDataType::SINGLE only.
  474. */
  475. __stdcall PostTrainData();
  476. __stdcall ~PostTrainData();
  477. /**
  478. * @brief Push data to train a PostTrainer. PostTrainer supports InputDataType::SINGLE only.
  479. * @param imgPath A path of the image to be trained.
  480. * @param classNum The class identifier for the image. If the total number of classes is N, this parameter should be a value less than N (classNum < N).
  481. Class numbers must be ordered with all old classes first (0..k), followed by the new classes (k+1..N-1):
  482. { 0 : old_class_0, 1 : old_class_1, ... k : old_class_k, k+1 : new_class_0, k+2 : new_class_1, ... N-1 : new_class_(N-k-2) }
  483. * @param maskImagePath A path of the mask image; defaults to 0 (no path). The mask image must be a 1-channel image with pixel values 0 or 255, where 0 means masked and 255 means not masked.
  484. * @param roi ROI (region of interest) rect; defaults to Rect().
  485. * @return Returns Status.
  486. */
  487. Status __stdcall PushDataInfo(const wchar_t* imgPath, SuaKIT_Int32 classNum, const wchar_t* maskImagePath = 0, const Rect& roi = Rect());
  488. Status __stdcall PushDataInfoWithJson(const wchar_t* imgPath, SuaKIT_Int32 classNum, const wchar_t* maskImageJson, const Rect& roi = Rect()); // Like PushDataInfo, but the mask is passed as maskImageJson (presumably JSON text rather than an image path - confirm).
  489. /**
  490. * @brief Get the length of this PostTrainData.
  491. * @return Returns the length of image data as SuaKIT_Int64 type.
  492. */
  493. SuaKIT_Int64 __stdcall GetTrainDataLength() const;
  494. void* _getInternal() const { return internal; } // Returns the opaque implementation handle; presumably for SuaKIT-internal use - not documented as public API.
  495. /**
  496. * @brief Explicitly releases resources. It is useful in non-native languages (such as C#).
  497. * @return Returns Status.
  498. */
  499. Status __stdcall Destroy();
  500. /**
  501. * @brief Get the current Status of this object. @return Returns Status.
  502. */
  503. Status __stdcall GetStatus() const;
  504. private: // NOTE(review): exported (SUAKIT_API) class - keep data member order and types exactly as in the prebuilt SuaKIT binary.
  505. PostTrainData(const PostTrainData& src); //=delete; private copy constructor (pre-C++11 idiom): clients cannot copy this object.
  506. PostTrainData& operator=(const PostTrainData& rhs); //=delete; private copy assignment (pre-C++11 idiom).
  507. void* internal; // opaque implementation handle, exposed via _getInternal()
  508. Status m_status;
  509. };
  510. /**
  511. * @brief This class is used to manage data for segmentation training.
  512. * @details This class contains image and label information for training SegmentationTrainer.
  513. * @author ⓒSualab SuaKIT Team
  514. */
  515. class SUAKIT_API SegmentationTrainData {
  516. public:
  517. /**
  518. * @brief A constructor of SegmentationTrainData.
  519. * @param inputDataType This parameter is used to configure the input type of the neural network. Set it to one of {InputDataType::SINGLE, InputDataType::COMPARISON, InputDataType::MULTIMG}; defaults to SINGLE.
  520. */
  521. __stdcall SegmentationTrainData(InputDataType inputDataType = SINGLE);
  522. __stdcall ~SegmentationTrainData();
  523. /**
  524. * @brief Push single data to train a SegmentationTrainer. It should be used for InputDataType::SINGLE only.
  525. * @param imgPath A path of the image to be trained.
  526. * @param labelImgPath A path of the label image. The label image is an 8-bit 1-channel image in which each pixel value is a class number and unlabeled areas have the pixel value 255.
  527. * @param maskImagePath A path of the mask image; defaults to 0 (no path). The mask image must be a 1-channel image with pixel values 0 or 255, where 0 means masked and 255 means not masked.
  528. * @param roi ROI (region of interest) rect; defaults to Rect().
  529. * @return Returns Status.
  530. */
  531. Status __stdcall PushSingleDataInfo(const wchar_t* imgPath, const wchar_t* labelImgPath, const wchar_t* maskImagePath = 0, const Rect& roi = Rect());
  532. /**
  533. * @brief Push comparison data (a master/slave image pair) to train a SegmentationTrainer. It should be used for InputDataType::COMPARISON only.
  534. * @param masterImagePath A path of the master image to be trained.
  535. * @param slaveImagePath A path of the slave image to be trained.
  536. * @param labelImgPath A path of the label image. The label image is an 8-bit 1-channel image in which each pixel value is a class number and unlabeled areas have the pixel value 255.
  537. * @param maskImagePath A path of the mask image; defaults to 0 (no path). The mask image must be a 1-channel image with pixel values 0 or 255, where 0 means masked and 255 means not masked.
  538. * @param roi ROI (region of interest) rect; defaults to Rect().
  539. * @return Returns Status.
  540. */
  541. Status __stdcall PushComparisonDataInfo(const wchar_t* masterImagePath, const wchar_t* slaveImagePath, const wchar_t* labelImgPath, const wchar_t* maskImagePath = 0, const Rect& roi = Rect());
  542. /**
  543. * @brief Push multi-image data to train a SegmentationTrainer. It should be used for InputDataType::MULTIMG only.
  544. * @param imagePaths Paths of the images to be trained, given as one string with the paths separated by newline characters ('\n'). For example: "C:\Folder\img1.png\nC:\Folder\img2.png\nC:\Folder\img3.png" (in a C++ literal, each path backslash must be escaped as \\). The order of the images is important.
  545. * @param labelImgPath A path of the label image. The label image is an 8-bit 1-channel image in which each pixel value is a class number and unlabeled areas have the pixel value 255.
  546. * @param maskImagePath A path of the mask image; defaults to 0 (no path). The mask image must be a 1-channel image with pixel values 0 or 255, where 0 means masked and 255 means not masked.
  547. * @param roi ROI (region of interest) rect; defaults to Rect().
  548. * @return Returns Status.
  549. */
  550. Status __stdcall PushMultiDataInfo(const wchar_t* imagePaths, const wchar_t* labelImgPath, const wchar_t* maskImagePath = 0, const Rect& roi = Rect());
  551. Status __stdcall PushSingleDataInfoWithJson(const wchar_t* imgPath, const wchar_t* labelJson, const wchar_t* maskImageJson, const Rect& roi = Rect()); // Like PushSingleDataInfo, but label and mask are passed as labelJson / maskImageJson (presumably JSON text - confirm).
  552. Status __stdcall PushComparisonDataInfoWithJson(const wchar_t* masterImagePath, const wchar_t* slaveImagePath, const wchar_t* labelJson, const wchar_t* maskImageJson, const Rect& roi = Rect()); // Like PushComparisonDataInfo, but label and mask are passed as JSON (presumably - confirm).
  553. Status __stdcall PushMultiDataInfoWithJson(const wchar_t* imagePaths, const wchar_t* labelJson, const wchar_t* maskImageJson, const Rect& roi = Rect()); // Like PushMultiDataInfo, but label and mask are passed as JSON (presumably - confirm).
  554. /**
  555. * @brief Get the InputDataType of this SegmentationTrainData.
  556. * @return Returns InputDataType.
  557. */
  558. InputDataType __stdcall GetDataType() const;
  559. /**
  560. * @brief Get the length of this SegmentationTrainData.
  561. * @return Returns the length of image data as SuaKIT_Int64 type.
  562. */
  563. SuaKIT_Int64 __stdcall GetTrainDataLength() const;
  564. void* _getInternal() const { return internal; } // Returns the opaque implementation handle; presumably for SuaKIT-internal use - not documented as public API.
  565. /**
  566. * @brief Explicitly releases resources. It is useful in non-native languages (such as C#).
  567. * @return Returns Status.
  568. */
  569. Status __stdcall Destroy();
  570. /**
  571. * @brief Get the current Status of this object. @return Returns Status.
  572. */
  573. Status __stdcall GetStatus() const;
  574. private: // NOTE(review): exported (SUAKIT_API) class - keep data member order and types exactly as in the prebuilt SuaKIT binary.
  575. SegmentationTrainData(const SegmentationTrainData& src); //=delete; private copy constructor (pre-C++11 idiom): clients cannot copy this object.
  576. SegmentationTrainData& operator=(const SegmentationTrainData& rhs); //=delete; private copy assignment (pre-C++11 idiom).
  577. void* internal; // opaque implementation handle, exposed via _getInternal()
  578. InputDataType m_dataType;
  579. Status m_status;
  580. };
  581. /**
  582. * @brief This class is used to manage data for one-class segmentation training.
  583. * @details This class contains image information for training OneClassSegmentationTrainer.
  584. * @author ⓒSualab SuaKIT Team
  585. */
  586. class SUAKIT_API OneClassSegmentationTrainData {
  587. public:
  588. /**
  589. * @brief A constructor of OneClassSegmentationTrainData.
  590. * @param inputDataType This parameter is used to configure the input type of the neural network. Set it to one of {InputDataType::SINGLE, InputDataType::COMPARISON, InputDataType::MULTIMG}; defaults to SINGLE.
  591. Currently, OneClassSegmentation supports only InputDataType::SINGLE.
  592. */
  593. __stdcall OneClassSegmentationTrainData(InputDataType inputDataType = SINGLE);
  594. __stdcall ~OneClassSegmentationTrainData();
  595. /**
  596. * @brief Push single data to train a OneClassSegmentationTrainer. It should be used for InputDataType::SINGLE only.
  597. * @param imgPath A path of the image to be trained.
  598. * @param maskImagePath A path of the mask image; defaults to 0 (no path). The mask image must be a 1-channel image with pixel values 0 or 255, where 0 means masked and 255 means not masked.
  599. * @param roi ROI (region of interest) rect; defaults to Rect().
  600. * @return Returns Status.
  601. */
  602. Status __stdcall PushSingleDataInfo(const wchar_t* imgPath, const wchar_t* maskImagePath = 0, const Rect& roi = Rect());
  603. /**
  604. * @brief Push comparison data (a master/slave image pair) to train a OneClassSegmentationTrainer. It should be used for InputDataType::COMPARISON only.
  605. This function is currently not supported.
  606. * @param masterImagePath A path of the master image to be trained.
  607. * @param slaveImagePath A path of the slave image to be trained.
  608. * @param maskImagePath A path of the mask image; defaults to 0 (no path). The mask image must be a 1-channel image with pixel values 0 or 255, where 0 means masked and 255 means not masked.
  609. * @param roi ROI (region of interest) rect; defaults to Rect().
  610. * @return Returns Status.
  611. */
  612. Status __stdcall PushComparisonDataInfo(const wchar_t* masterImagePath, const wchar_t* slaveImagePath, const wchar_t* maskImagePath = 0, const Rect& roi = Rect());
  613. /**
  614. * @brief InputDataType::MULTIMG is currently not supported.
  615. Push multi-image data to train a OneClassSegmentationTrainer. It should be used for InputDataType::MULTIMG only.
  616. * @param imagePaths Paths of the images to be trained, given as one string with the paths separated by newline characters ('\n'). For example: "C:\Folder\img1.png\nC:\Folder\img2.png\nC:\Folder\img3.png" (in a C++ literal, each path backslash must be escaped as \\). The order of the images is important.
  617. * @param maskImagePath A path of the mask image; defaults to 0 (no path). The mask image must be a 1-channel image with pixel values 0 or 255, where 0 means masked and 255 means not masked.
  618. * @param roi ROI (region of interest) rect; defaults to Rect().
  619. * @return Returns Status.
  620. */
  621. Status __stdcall PushMultiDataInfo(const wchar_t* imagePaths, const wchar_t* maskImagePath = 0, const Rect& roi = Rect());
  622. Status __stdcall PushSingleDataInfoWithJson(const wchar_t* imgPath, const wchar_t* maskImageJson, const Rect& roi = Rect()); // Like PushSingleDataInfo, but the mask is passed as maskImageJson (presumably JSON text rather than an image path - confirm).
  623. Status __stdcall PushComparisonDataInfoWithJson(const wchar_t* masterImagePath, const wchar_t* slaveImagePath, const wchar_t* maskImageJson, const Rect& roi = Rect()); // Like PushComparisonDataInfo, but the mask is passed as maskImageJson (presumably JSON text - confirm).
  624. Status __stdcall PushMultiDataInfoWithJson(const wchar_t* imagePaths, const wchar_t* maskImageJson, const Rect& roi = Rect()); // Like PushMultiDataInfo, but the mask is passed as maskImageJson (presumably JSON text - confirm).
  625. /**
  626. * @brief Get the InputDataType of this OneClassSegmentationTrainData.
  627. * @return Returns InputDataType.
  628. */
  629. InputDataType __stdcall GetDataType() const;
  630. /**
  631. * @brief Get the length of this OneClassSegmentationTrainData.
  632. * @return Returns the length of image data as SuaKIT_Int64 type.
  633. */
  634. SuaKIT_Int64 __stdcall GetTrainDataLength() const;
  635. void* _getInternal() const { return internal; } // Returns the opaque implementation handle; presumably for SuaKIT-internal use - not documented as public API.
  636. /**
  637. * @brief Explicitly releases resources. It is useful in non-native languages (such as C#).
  638. * @return Returns Status.
  639. */
  640. Status __stdcall Destroy();
  641. /**
  642. * @brief Get the current Status of this object. @return Returns Status.
  643. */
  644. Status __stdcall GetStatus() const;
  645. private: // NOTE(review): exported (SUAKIT_API) class - keep data member order and types exactly as in the prebuilt SuaKIT binary.
  646. OneClassSegmentationTrainData(const OneClassSegmentationTrainData& src); //=delete; private copy constructor (pre-C++11 idiom): clients cannot copy this object.
  647. OneClassSegmentationTrainData& operator=(const OneClassSegmentationTrainData& rhs); //=delete; private copy assignment (pre-C++11 idiom).
  648. void* internal; // opaque implementation handle, exposed via _getInternal()
  649. InputDataType m_dataType;
  650. Status m_status;
  651. };
  652. /**
  653. * @brief This class is used to manage data for detection training.
  654. * @details This class contains image and label information for training DetectionTrainer.
  655. * @author ⓒSualab SuaKIT Team
  656. */
  657. class SUAKIT_API DetectionTrainData {
  658. public:
  659. /**
  660. * @brief A constructor of DetectionTrainData.
  661. * @param inputDataType This parameter is used to configure the input type of the neural network. Set it to one of {InputDataType::SINGLE, InputDataType::COMPARISON, InputDataType::MULTIMG}; defaults to SINGLE.
  662. InputDataType::COMPARISON and InputDataType::MULTIMG are currently not supported. These features will be added later.
  663. */
  664. __stdcall DetectionTrainData(InputDataType inputDataType = SINGLE);
  665. __stdcall ~DetectionTrainData();
  666. /**
  667. * @brief Push single data to train a DetectionTrainer. It should be used for InputDataType::SINGLE only.
  668. * @param imgPath A path of the image to be trained.
  669. * @param labelRects An array of rectangles, one per object, each holding the rect coordinates (x, y, width, height) and the object's class number.
  670. * @param maskImagePath A path of the mask image; defaults to 0 (no path). The mask image must be a 1-channel image with pixel values 0 or 255, where 0 means masked and 255 means not masked.
  671. * @param roi ROI (region of interest) rect; defaults to Rect().
  672. * @return Returns Status.
  673. */
  674. Status __stdcall PushSingleDataInfo(const wchar_t* imgPath, const RectArray& labelRects, const wchar_t* maskImagePath = 0, const Rect& roi = Rect());
  675. /**
  676. * @brief InputDataType::COMPARISON is currently not supported. This feature will be added later.
  677. Push comparison data (a master/slave image pair) to train a DetectionTrainer. It should be used for InputDataType::COMPARISON only.
  678. * @param masterImagePath A path of the master image to be trained.
  679. * @param slaveImagePath A path of the slave (replica) image to be trained.
  680. * @param labelRects An array of rectangles, one per object, each holding the rect coordinates (x, y, width, height) and the object's class number.
  681. * @param maskImagePath A path of the mask image; defaults to 0 (no path). The mask image must be a 1-channel image with pixel values 0 or 255, where 0 means masked and 255 means not masked.
  682. * @param roi ROI (region of interest) rect; defaults to Rect().
  683. * @return Returns Status.
  684. */
  685. Status __stdcall PushComparisonDataInfo(const wchar_t* masterImagePath, const wchar_t* slaveImagePath, const RectArray& labelRects, const wchar_t* maskImagePath = 0, const Rect& roi = Rect());
  686. /**
  687. * @brief InputDataType::MULTIMG is currently not supported. This feature will be added later.
  688. Push multi-image data to train a DetectionTrainer. It should be used for InputDataType::MULTIMG only.
  689. * @param imagePaths Paths of the images to be trained, given as one string with the paths separated by newline characters ('\n'). For example: "C:\Folder\img1.png\nC:\Folder\img2.png\nC:\Folder\img3.png" (in a C++ literal, each path backslash must be escaped as \\). The order of the images is important.
  690. * @param labelRects An array of rectangles, one per object, each holding the rect coordinates (x, y, width, height) and the object's class number.
  691. * @param maskImagePath A path of the mask image; defaults to 0 (no path). The mask image must be a 1-channel image with pixel values 0 or 255, where 0 means masked and 255 means not masked.
  692. * @param roi ROI (region of interest) rect; defaults to Rect().
  693. * @return Returns Status.
  694. */
  695. Status __stdcall PushMultiDataInfo(const wchar_t* imagePaths, const RectArray& labelRects, const wchar_t* maskImagePath = 0, const Rect& roi = Rect());
  696. Status __stdcall PushSingleDataInfoWithJson(const wchar_t* imgPath, const RectArray& labelRects, const wchar_t* maskImageJson, const Rect& roi = Rect()); // Like PushSingleDataInfo, but the mask is passed as maskImageJson (presumably JSON text rather than an image path - confirm).
  697. Status __stdcall PushComparisonDataInfoWithJson(const wchar_t* masterImagePath, const wchar_t* slaveImagePath, const RectArray& labelRects, const wchar_t* maskImageJson, const Rect& roi = Rect()); // Like PushComparisonDataInfo, but the mask is passed as maskImageJson (presumably JSON text - confirm).
  698. Status __stdcall PushMultiDataInfoWithJson(const wchar_t* imagePaths, const RectArray& labelRects, const wchar_t* maskImageJson, const Rect& roi = Rect()); // Like PushMultiDataInfo, but the mask is passed as maskImageJson (presumably JSON text - confirm).
  699. /**
  700. * @brief Get the InputDataType of this DetectionTrainData.
  701. * @return Returns InputDataType.
  702. */
  703. InputDataType __stdcall GetDataType() const;
  704. /**
  705. * @brief Get the length of this DetectionTrainData.
  706. * @return Returns the length of image data as SuaKIT_Int64 type.
  707. */
  708. SuaKIT_Int64 __stdcall GetTrainDataLength() const;
  709. void* _getInternal() const { return internal; } // Returns the opaque implementation handle; presumably for SuaKIT-internal use - not documented as public API.
  710. /**
  711. * @brief Explicitly releases resources. It is useful in non-native languages (such as C#).
  712. * @return Returns Status.
  713. */
  714. Status __stdcall Destroy();
  715. /**
  716. * @brief Get the current Status of this object. @return Returns Status.
  717. */
  718. Status __stdcall GetStatus() const;
  719. private: // NOTE(review): exported (SUAKIT_API) class - keep data member order and types exactly as in the prebuilt SuaKIT binary.
  720. DetectionTrainData(const DetectionTrainData& src); //=delete; private copy constructor (pre-C++11 idiom): clients cannot copy this object.
  721. DetectionTrainData& operator=(const DetectionTrainData& rhs); //=delete; private copy assignment (pre-C++11 idiom).
  722. void* internal; // opaque implementation handle, exposed via _getInternal()
  723. InputDataType m_dataType;
  724. Status m_status;
  725. };
  726. /**
  727. * @brief This class is used at classification training time.
  728. * @details This class can produce a new artificial neural network that solves a classification problem. The user must prepare a ClassificationTrainConfig object and a ClassificationTrainData object.
  729. * @author ⓒSualab SuaKIT Team
  730. */
  731. class SUAKIT_API ClassificationTrainer {
  732. public:
  733. /**
  734. * @brief A constructor of ClassificationTrainer.
  735. * @param inputDataType This parameter is used to configure the input type of the neural network. Set it to one of {InputDataType::SINGLE, InputDataType::COMPARISON, InputDataType::MULTIMG}.
  736. * @param devDesc Specifies the device on which the ClassificationTrainer will operate. ClassificationTrainer operates on one device and supports CPU and GPU.
  737. * @param traincfg Configuration of ClassificationTrainer.
  738. * @param modelSavePath After training is finished, the ClassificationTrainer saves the best model to disk under the path specified by modelSavePath.
  739. * @param baseModelPath The base (pre-defined) model path for executing ClassificationTrainer; defaults to 0.
  740. Because backslashes (`\`) must be escaped in `const wchar_t*` string literals, we recommend using paths with forward slashes (`/`).
  741. For example, `L"C:/suakit/weights/small"` is preferred over `L"C:\\suakit\\weights\\small"`.
  742. Files named small, normal, large and extra_large are located in the weights folder of the SuaKIT installation directory.
  743. Each file stores the structure of a pre-defined model:
  744. - small refers to the small model.
  745. - normal refers to the normal model.
  746. - large refers to the large model.
  747. - extra_large refers to the extra large model.
  748. You can also set this value to a trained model path to continue training (currently, continue-training mode is supported for single classification only).
  749. * @note A constructor has no return value; presumably the construction result is reported by GetStatus() - confirm.
  750. */
  751. __stdcall ClassificationTrainer(InputDataType inputDataType, const DeviceDescriptor& devDesc, const ClassificationTrainConfig& traincfg, const wchar_t * modelSavePath, const wchar_t * baseModelPath = 0); // baseModelPath = pretrained model
  752. __stdcall ~ClassificationTrainer();
  753. /**
  754. * @brief Get the current Status of this trainer. @return Returns Status.
  755. */
  756. Status __stdcall GetStatus() const;
  757. /**
  758. * @brief Get the InputDataType of this ClassificationTrainer.
  759. * @return Returns InputDataType.
  760. */
  761. InputDataType __stdcall GetDataType() const;
  762. /**
  763. * @brief Returns true while this ClassificationTrainer is training; otherwise returns false.
  764. * @return Returns whether training is currently in progress.
  765. */
  766. bool __stdcall isTraining() const;
  767. /**
  768. * @brief Starts training the artificial neural network.
  769. * @details This function takes a single argument: the ClassificationTrainData to train on.
  770. It works asynchronously and returns almost immediately.
  771. This function internally creates a thread to train, so you do not need to write an explicit thread wrapper when you want to do something else.
  772. With this overload, ClassificationTrainer uses a certain percentage of the training set as the validation set.
  773. * @param trainData ClassificationTrainData to train. See ClassificationTrainData for details.
  774. * @return Returns Status.
  775. */
  776. Status __stdcall StartTrain(const ClassificationTrainData& trainData);
  777. /**
  778. * @brief Starts training the artificial neural network with an explicit validation set.
  779. * @details This function takes two ClassificationTrainData arguments: one for training and one for validation.
  780. It works asynchronously and returns almost immediately.
  781. This function internally creates a thread to train, so you do not need to write an explicit thread wrapper when you want to do something else.
  782. This overload uses the explicitly given validation set while training.
  783. * @param trainData ClassificationTrainData to train. See ClassificationTrainData for details.
  784. * @param validationData ClassificationTrainData used explicitly as the validation set.
  785. * @return Returns Status.
  786. */
  787. Status __stdcall StartTrain(const ClassificationTrainData& trainData, const ClassificationTrainData& validationData);
  788. /**
  789. * @brief Blocks until training is finished.
  790. * @details Calling this function blocks until training terminates.
  791. * @return Returns Status.
  792. */
  793. Status __stdcall WaitTrain();
  794. /**
  795. * @brief Returns a Message representing the current training state.
  796. * @details As training progresses, ClassificationTrainer pushes messages into an internally managed message queue.
  797. Each call to GetTrainMessage dequeues one message from that queue and returns it to the user as a Message object.
  798. If the queue is empty, the call blocks until a message is enqueued, then dequeues and returns it immediately.
  799. * @return Returns Message object.
  800. */
  801. Message __stdcall GetTrainMessage();
  802. /**
  803. * @brief Returns whether the internal Message queue is empty.
  804. * @details ClassificationTrainer uses a message queue for communication. This function returns true when that queue is empty, otherwise false.
  805. * @return true if the message queue is empty, otherwise false.
  806. */
  807. bool __stdcall MessageIsEmpty();
  808. /**
  809. * @brief Stops training if it is currently in progress. If training is not in progress, the call does nothing.
  810. * @return Returns Status.
  811. */
  812. Status __stdcall StopTrain();
  813. /**
  814. * @brief Explicitly releases resources. It is useful in non-native languages (such as C#).
  815. * @return Returns Status.
  816. */
  817. Status __stdcall Destroy();
  818. /**
  819. * @brief Returns the maximum image count that can be used when training a multi-image (InputDataType::MULTIMG) model.
  820. The multiImgCount argument of the ClassificationTrainConfig constructor must not exceed this return value.
  821. * @param imageHeight Maximum height of the images to use for training.
  822. * @param imageWidth Maximum width of the images to use for training.
  823. * @param imageChannel Number of channels of the images to use for training.
  824. * @return Returns the maximum image count.
  825. */
  826. static SuaKIT_Int64 GetMaxImageCount( // NOTE(review): unlike the other members this is not declared __stdcall; keep as-is to match the exported symbol.
  827. SuaKIT_Int32 imageHeight,
  828. SuaKIT_Int32 imageWidth,
  829. SuaKIT_Int32 imageChannel
  830. );
  831. private: // NOTE(review): exported (SUAKIT_API) class - keep data member order and types exactly as in the prebuilt SuaKIT binary.
  832. ClassificationTrainer(const ClassificationTrainer& src); //=delete; private copy constructor (pre-C++11 idiom): clients cannot copy this object.
  833. ClassificationTrainer& operator=(const ClassificationTrainer& rhs); //=delete; private copy assignment (pre-C++11 idiom).
  834. DeviceDescriptor m_devDesc;
  835. InputDataType m_dataType;
  836. Status m_status;
  837. void* internal; // opaque implementation handle
  838. };
  839. /**
  840. * @brief This class is used at Post training time.
  841. * @details This class can produce a new artificial neural network that can solve post training problem. The user must prepare a ClassificationTrainConfig object and a ClassificationTrainData object.
  842. * @author ⓒSualab SuaKIT Team
  843. */
  844. class SUAKIT_API PostTrainer
  845. {
  846. public:
  847. /**
  848. * @brief A constructor of PostTrainer.
  849. * @param devDesc Specifies the device on which the PostTrainer will operate. PostTrainer can operate on one device and support CPU and GPU.
  850. * @param traincfg Configuration of PostTrainer.
  851. * @param modelSavePath After the train is finished, the PostTrainer will save the best model to disk and specify the path to the best model to be stored. The best model is saved under the path specified by modelSavePath.
  852. * @param trainedModelPath Trained model (by SuaKIT ClassificationTrainer or PostTrainer) path for executing PostTrainer.
  853. As using backslash(`\`) characters in `const wchar_t*` strings is inconvenient for a variety of reasons, we recommend using paths with forward slashes(`/`).
  854. For example, `L"C:/suakit/weights/small"` is preferred over `L"C:\\suakit\\weights\\small"`.
  855. * @return Returns Status.
  856. */
  857. __stdcall PostTrainer(const DeviceDescriptor& devDesc, const PostTrainConfig& traincfg, const wchar_t * modelSavePath, const wchar_t * trainedModelPath);
  858. __stdcall ~PostTrainer();
  859. /**
  860. * @return Returns Status.
  861. */
  862. Status __stdcall GetStatus() const;
  863. /**
  864. * @brief During PostTrainer object is training, this function returns true. Else return false.
  865. * @return Returns boolean parameter represents currently training status.
  866. */
  867. bool __stdcall isTraining() const;
  868. /**
  869. * @brief Trainer starts to train artificial neural network by calling this function.
  870. * @details This function takes just one argument. PostTrainData.
  871. It works asynchronously and terminates the function almost simultaneously with function execution.
  872. This function internally creates a thread to train, so you do not need to write an explicit thread wrapper when you want to do something else.
  873. When users use this function, PostTrainer will use a certain percentage of the training set as a validation set.
  874. * @param trainData PostTrainData to train. See PostTrainData for details.
  875. * @return Returns Status.
  876. */
  877. Status __stdcall StartTrain(const PostTrainData& trainData);
  878. /**
  879. * @brief Trainer starts to train artificial neural network by calling this function.
  880. * @details This function takes just two arguments. They are both PostTrainData. One is for train data, and the other is for validation data.
  881. It works asynchronously and terminates the function almost simultaneously with function execution.
  882. This function internally creates a thread to train, so you do not need to write an explicit thread wrapper when you want to do something else.
  883. This function explicitly sets the validation set and proceeds with the learning.
  884. * @param trainData PostTrainData to train. See PostTrainData for details.
  885. * @param validationData Learn by specifying explicit validation data.
  886. * @return Returns Status.
  887. */
  888. Status __stdcall StartTrain(const PostTrainData& trainData, const PostTrainData& validationData);
  889. /**
  890. * @brief Blocking function until the train is finished.
  891. * @details Calling this function will block until the train is terminated.
  892. * @return Returns Status.
  893. */
  894. Status __stdcall WaitTrain();
  895. /**
  896. * @brief Returns a Message representing the current training state.
  897. * @details PostTrainer pushes its message into an internally managed message queue and proceeds with learning as it progresses.
  898. When the GetTrainMessage function is called, the internally managed message in the queue is dequeued and returned to the user as a Message object.
  899. If there is no message in the internal message queue, it will be blocked until a message is enqueued. If a message is enqueued, the message will be dequeued and returned immediately.
  900. * @return Returns Message object.
  901. */
  902. Message __stdcall GetTrainMessage();
  903. /**
  904. * @brief Returns whether the internal Message queue is empty.
  905. * @details PostTrainer has a message queue for communication. This function returns true when that queue is empty, otherwise false.
  906. * @return true if the message queue is empty, false otherwise.
  907. */
  908. bool __stdcall MessageIsEmpty();
  909. /**
  910. * @brief Stops the current training step if training is in progress. If training is not in progress, the call does nothing.
  911. PostTrain consists of several training steps; only the current step is stopped, and the next step then proceeds automatically.
  912. If you call this function before the current step has reached the main training process (i.e. while it is still in its initialization step),
  913. the remaining steps cannot proceed and no result model will be produced.
  914. * @return Returns Status.
  915. */
  916. Status __stdcall StopCurrentTrainStep();
  917. /**
  918. * @brief Aborts all training steps if training is in progress. If training is not in progress, the call does nothing.
  919. This function stops every remaining PostTrain step, so no result model will be produced.
  920. * @return Returns Status.
  921. */
  922. Status __stdcall AbortAllTrain();
  923. /**
  924. * @brief Explicitly releases resources. This is useful when the API is used from non-native languages (such as C#).
  925. * @return Returns Status.
  926. */
  927. Status __stdcall Destroy();
  928. private:
  929. PostTrainer(const PostTrainer& src); //=delete; copy constructor declared private so PostTrainer cannot be copied (pre-C++11 idiom)
  930. PostTrainer& operator=(const PostTrainer& rhs); //=delete; copy assignment declared private so PostTrainer cannot be copy-assigned (pre-C++11 idiom)
  931. DeviceDescriptor m_devDesc;
  932. Status m_status;
  933. void* internal; // opaque handle to internal state (presumably a pimpl keeping implementation types out of this header — TODO confirm)
  934. };
  935. /**
  936. * @brief This class is used at segmentation training time.
  937. * @details This class can produce a new artificial neural network that can solve segmentation problems. The user must prepare a SegmentationTrainConfig object and a SegmentationTrainData object.
  938. * @author ⓒSualab SuaKIT Team
  939. */
  940. class SUAKIT_API SegmentationTrainer {
  941. public:
  942. /**
  943. * @brief A constructor of SegmentationTrainer.
  944. * @param inputDataType Configures the input type of the neural network. Set it to one of {InputDataType::SINGLE, InputDataType::COMPARISON, InputDataType::MULTIMG}.
  945. * @param devDesc Specifies the device on which the SegmentationTrainer will operate. SegmentationTrainer can operate on one device and supports CPU and GPU.
  946. * @param traincfg Configuration of SegmentationTrainer.
  947. * @param modelSavePath Path under which the SegmentationTrainer saves the best model to disk after training is finished.
  948. * @param baseModelPath The base (pre-defined) model path for executing SegmentationTrainer. Defaults to 0 (a null pointer).
  949. As using backslash(`\`) characters in `const wchar_t*` strings is inconvenient for a variety of reasons, we recommend using paths with forward slashes(`/`).
  950. For example, `L"C:/suakit/weights/small"` is preferred over `L"C:\\suakit\\weights\\small"`.
  951. There are files named {small, normal, large, extra_large} inside the weights folder in the path where the suakit is installed.
  952. This is the structure of the pre-defined model saved as a file.
  953. - small refers to the small model.
  954. - normal refers to the normal model.
  955. - large refers to the large model.
  956. - extra_large refers to the extra large model.
  957. You can also set this value to a trained model path for continue training mode. (Currently, continue training mode is supported only for single classification.)
  958. * @note Constructors cannot return a value; use GetStatus() to retrieve this trainer's Status (presumably reflecting whether construction succeeded — TODO confirm).
  959. */
  960. __stdcall SegmentationTrainer(InputDataType inputDataType, const DeviceDescriptor& devDesc, const SegmentationTrainConfig& traincfg, const wchar_t * modelSavePath, const wchar_t * baseModelPath = 0);
  961. __stdcall ~SegmentationTrainer();
  962. /**
  963. * @brief Returns the Status of this SegmentationTrainer. @return Returns Status.
  964. */
  965. Status __stdcall GetStatus() const;
  966. /**
  967. * @brief Returns the InputDataType of this SegmentationTrainer.
  968. * @return Returns InputDataType.
  969. */
  970. InputDataType __stdcall GetDataType() const;
  971. /**
  972. * @brief Returns true while this SegmentationTrainer object is training; otherwise returns false.
  973. * @return Whether training is currently in progress.
  974. */
  975. bool __stdcall isTraining() const;
  976. /**
  977. * @brief Starts training the artificial neural network.
  978. * @details This function takes just one argument: a SegmentationTrainData.
  979. It works asynchronously: the call returns almost immediately while training continues in the background.
  980. This function internally creates a thread to train, so you do not need to write an explicit thread wrapper when you want to do something else.
  981. When users use this function, SegmentationTrainer will use a certain percentage of the training set as a validation set.
  982. * @param trainData SegmentationTrainData to train. See SegmentationTrainData for details.
  983. * @return Returns Status.
  984. */
  985. Status __stdcall StartTrain(const SegmentationTrainData& trainData);
  986. /**
  987. * @brief Starts training the artificial neural network with an explicit validation set.
  988. * @details This function takes just two arguments. They are both SegmentationTrainData. One is for train data, and the other is for validation data.
  989. It works asynchronously: the call returns almost immediately while training continues in the background.
  990. This function internally creates a thread to train, so you do not need to write an explicit thread wrapper when you want to do something else.
  991. This function uses validationData explicitly as the validation set and proceeds with training.
  992. * @param trainData SegmentationTrainData to train. See SegmentationTrainData for details.
  993. * @param validationData SegmentationTrainData used explicitly as the validation set.
  994. * @return Returns Status.
  995. */
  996. Status __stdcall StartTrain(const SegmentationTrainData& trainData, const SegmentationTrainData& validationData);
  997. /**
  998. * @brief Blocks until training is finished.
  999. * @details Calling this function blocks the caller until training terminates.
  1000. * @return Returns Status.
  1001. */
  1002. Status __stdcall WaitTrain();
  1003. /**
  1004. * @brief Returns a Message representing the current training state.
  1005. * @details As training progresses, SegmentationTrainer pushes messages into an internally managed message queue.
  1006. When the GetTrainMessage function is called, the next message in that queue is dequeued and returned to the user as a Message object.
  1007. If the internal message queue is empty, the call blocks until a message is enqueued; that message is then dequeued and returned immediately.
  1008. * @return Returns Message object.
  1009. */
  1010. Message __stdcall GetTrainMessage();
  1011. /**
  1012. * @brief Returns whether the internal Message queue is empty.
  1013. * @details SegmentationTrainer has a message queue for communication. This function returns true when that queue is empty, otherwise false.
  1014. * @return true if the message queue is empty, false otherwise.
  1015. */
  1016. bool __stdcall MessageIsEmpty();
  1017. /**
  1018. * @brief Stops training if it is currently in progress. If training is not in progress, the call does nothing.
  1019. * @return Returns Status.
  1020. */
  1021. Status __stdcall StopTrain();
  1022. /**
  1023. * @brief Explicitly releases resources. This is useful when the API is used from non-native languages (such as C#).
  1024. * @return Returns Status.
  1025. */
  1026. Status __stdcall Destroy();
  1027. private:
  1028. SegmentationTrainer(const SegmentationTrainer& src); //=delete; copy constructor declared private so SegmentationTrainer cannot be copied (pre-C++11 idiom)
  1029. SegmentationTrainer& operator=(const SegmentationTrainer& rhs); //=delete; copy assignment declared private so SegmentationTrainer cannot be copy-assigned (pre-C++11 idiom)
  1030. DeviceDescriptor m_devDesc;
  1031. InputDataType m_dataType;
  1032. Status m_status;
  1033. void* internal; // opaque handle to internal state (presumably a pimpl — TODO confirm)
  1034. };
  1035. /**
  1036. * @brief This class is used at OneClassSegmentation training time.
  1037. * @details This class can produce a new artificial neural network that can solve OneClassSegmentation problems. The user must prepare a OneClassSegmentationTrainConfig object and a OneClassSegmentationTrainData object.
  1038. * @author ⓒSualab SuaKIT Team
  1039. */
  1040. class SUAKIT_API OneClassSegmentationTrainer {
  1041. public:
  1042. /**
  1043. * @brief A constructor of OneClassSegmentationTrainer.
  1044. * @param inputDataType Configures the input type of the neural network. Set it to one of {InputDataType::SINGLE, InputDataType::COMPARISON, InputDataType::MULTIMG}.
  1045. Currently, OneClassSegmentation supports only InputDataType::SINGLE.
  1046. * @param devDesc Specifies the device on which the OneClassSegmentationTrainer will operate. OneClassSegmentationTrainer can operate on one device and supports CPU and GPU.
  1047. * @param traincfg Configuration of OneClassSegmentationTrainer.
  1048. * @param modelSavePath Path under which the OneClassSegmentationTrainer saves the best model to disk after training is finished.
  1049. * @param baseModelPath The base (pre-defined) model path for executing OneClassSegmentationTrainer. Defaults to 0 (a null pointer).
  1050. Currently, pre-trained base models are not supported.
  1051. You can also set this value to a trained model path for continue training mode. (Currently, continue training mode is supported only for single classification.)
  1052. * @note Constructors cannot return a value; use GetStatus() to retrieve this trainer's Status (presumably reflecting whether construction succeeded — TODO confirm).
  1053. */
  1054. __stdcall OneClassSegmentationTrainer(InputDataType inputDataType, const DeviceDescriptor& devDesc, const OneClassSegmentationTrainConfig& traincfg, const wchar_t * modelSavePath, const wchar_t * baseModelPath = 0);
  1055. __stdcall ~OneClassSegmentationTrainer();
  1056. /**
  1057. * @brief Returns the Status of this OneClassSegmentationTrainer. @return Returns Status.
  1058. */
  1059. Status __stdcall GetStatus() const;
  1060. /**
  1061. * @brief Returns the InputDataType of this OneClassSegmentationTrainer.
  1062. * @return Returns InputDataType.
  1063. */
  1064. InputDataType __stdcall GetDataType() const;
  1065. /**
  1066. * @brief Returns true while this OneClassSegmentationTrainer object is training; otherwise returns false.
  1067. * @return Whether training is currently in progress.
  1068. */
  1069. bool __stdcall isTraining() const;
  1070. /**
  1071. * @brief Starts training the artificial neural network.
  1072. * @details This function takes just one argument: a OneClassSegmentationTrainData.
  1073. It works asynchronously: the call returns almost immediately while training continues in the background.
  1074. This function internally creates a thread to train, so you do not need to write an explicit thread wrapper when you want to do something else.
  1075. When users use this function, OneClassSegmentationTrainer will use a certain percentage of the training set as a validation set.
  1076. * @param trainData OneClassSegmentationTrainData to train. See OneClassSegmentationTrainData for details.
  1077. * @return Returns Status.
  1078. */
  1079. Status __stdcall StartTrain(const OneClassSegmentationTrainData& trainData);
  1080. /**
  1081. * @brief Starts training the artificial neural network with an explicit validation set.
  1082. * @details This function takes just two arguments. They are both OneClassSegmentationTrainData. One is for train data, and the other is for validation data.
  1083. It works asynchronously: the call returns almost immediately while training continues in the background.
  1084. This function internally creates a thread to train, so you do not need to write an explicit thread wrapper when you want to do something else.
  1085. This function uses validationData explicitly as the validation set and proceeds with training.
  1086. * @param trainData OneClassSegmentationTrainData to train. See OneClassSegmentationTrainData for details.
  1087. * @param validationData OneClassSegmentationTrainData used explicitly as the validation set.
  1088. * @return Returns Status.
  1089. */
  1090. Status __stdcall StartTrain(const OneClassSegmentationTrainData& trainData, const OneClassSegmentationTrainData& validationData);
  1091. /**
  1092. * @brief Blocks until training is finished.
  1093. * @details Calling this function blocks the caller until training terminates.
  1094. * @return Returns Status.
  1095. */
  1096. Status __stdcall WaitTrain();
  1097. /**
  1098. * @brief Returns a Message representing the current training state.
  1099. * @details As training progresses, OneClassSegmentationTrainer pushes messages into an internally managed message queue.
  1100. When the GetTrainMessage function is called, the next message in that queue is dequeued and returned to the user as a Message object.
  1101. If the internal message queue is empty, the call blocks until a message is enqueued; that message is then dequeued and returned immediately.
  1102. * @return Returns Message object.
  1103. */
  1104. Message __stdcall GetTrainMessage();
  1105. /**
  1106. * @brief Returns whether the internal Message queue is empty.
  1107. * @details OneClassSegmentationTrainer has a message queue for communication. This function returns true when that queue is empty, otherwise false.
  1108. * @return true if the message queue is empty, false otherwise.
  1109. */
  1110. bool __stdcall MessageIsEmpty();
  1111. /**
  1112. * @brief Stops training if it is currently in progress. If training is not in progress, the call does nothing.
  1113. * @return Returns Status.
  1114. */
  1115. Status __stdcall StopTrain();
  1116. /**
  1117. * @brief Explicitly releases resources. This is useful when the API is used from non-native languages (such as C#).
  1118. * @return Returns Status.
  1119. */
  1120. Status __stdcall Destroy();
  1121. private:
  1122. OneClassSegmentationTrainer(const OneClassSegmentationTrainer& src); //=delete; copy constructor declared private so OneClassSegmentationTrainer cannot be copied (pre-C++11 idiom)
  1123. OneClassSegmentationTrainer& operator=(const OneClassSegmentationTrainer& rhs); //=delete; copy assignment declared private so OneClassSegmentationTrainer cannot be copy-assigned (pre-C++11 idiom)
  1124. DeviceDescriptor m_devDesc;
  1125. InputDataType m_dataType;
  1126. Status m_status;
  1127. void* internal; // opaque handle to internal state (presumably a pimpl — TODO confirm)
  1128. };
  1129. /**
  1130. * @brief This class is used at detection training time.
  1131. * @details This class can produce a new artificial neural network that can solve detection problems. The user must prepare a DetectionTrainConfig object and a DetectionTrainData object.
  1132. * @author ⓒSualab SuaKIT Team
  1133. */
  1134. class SUAKIT_API DetectionTrainer {
  1135. public:
  1136. /**
  1137. * @brief A constructor of DetectionTrainer.
  1138. * @param inputDataType Configures the input type of the neural network. Set it to one of {InputDataType::SINGLE, InputDataType::COMPARISON, InputDataType::MULTIMG}.
  1139. InputDataType::COMPARISON and InputDataType::MULTIMG are currently not supported. These features will be added later.
  1140. * @param devDesc Specifies the device on which the DetectionTrainer will operate. DetectionTrainer can operate on one device and supports CPU and GPU.
  1141. * @param traincfg Configuration of DetectionTrainer.
  1142. * @param modelSavePath Path under which the DetectionTrainer saves the best model to disk after training is finished.
  1143. * @param baseModelPath The base (pre-defined) model path for executing DetectionTrainer. Defaults to 0 (a null pointer).
  1144. As using backslash(`\`) characters in `const wchar_t*` strings is inconvenient for a variety of reasons, we recommend using paths with forward slashes(`/`).
  1145. For example, `L"C:/suakit/weights/small"` is preferred over `L"C:\\suakit\\weights\\small"`.
  1146. There are files named {small, normal, detection_large, extra_large} inside the weights folder in the path where the suakit is installed.
  1147. This is the structure of the pre-defined model saved as a file.
  1148. - small refers to the small model.
  1149. - normal refers to the normal model.
  1150. - detection_large refers to the large model.
  1151. - extra_large refers to the extra large model.
  1152. You can also set this value to a trained model path for continue training mode. (Currently, continue training mode is supported only for single classification.)
  1153. * @note Constructors cannot return a value; use GetStatus() to retrieve this trainer's Status (presumably reflecting whether construction succeeded — TODO confirm).
  1154. */
  1155. __stdcall DetectionTrainer(InputDataType inputDataType, const DeviceDescriptor& devDesc, const DetectionTrainConfig& traincfg, const wchar_t * modelSavePath, const wchar_t * baseModelPath = 0);
  1156. __stdcall ~DetectionTrainer();
  1157. /**
  1158. * @brief Returns the Status of this DetectionTrainer. @return Returns Status.
  1159. */
  1160. Status __stdcall GetStatus() const;
  1161. /**
  1162. * @brief Returns the InputDataType of this DetectionTrainer.
  1163. * @return Returns InputDataType.
  1164. */
  1165. InputDataType __stdcall GetDataType() const;
  1166. /**
  1167. * @brief Returns true while this DetectionTrainer object is training; otherwise returns false.
  1168. * @return Whether training is currently in progress.
  1169. */
  1170. bool __stdcall isTraining() const;
  1171. /**
  1172. * @brief Starts training the artificial neural network.
  1173. * @details This function takes just one argument: a DetectionTrainData.
  1174. It works asynchronously: the call returns almost immediately while training continues in the background.
  1175. This function internally creates a thread to train, so you do not need to write an explicit thread wrapper when you want to do something else.
  1176. When users use this function, DetectionTrainer will use a certain percentage of the training set as a validation set.
  1177. * @param trainData DetectionTrainData to train. See DetectionTrainData for details.
  1178. * @return Returns Status.
  1179. */
  1180. Status __stdcall StartTrain(const DetectionTrainData& trainData);
  1181. /**
  1182. * @brief Starts training the artificial neural network with an explicit validation set.
  1183. * @details This function takes just two arguments. They are both DetectionTrainData. One is for train data, and the other is for validation data.
  1184. It works asynchronously: the call returns almost immediately while training continues in the background.
  1185. This function internally creates a thread to train, so you do not need to write an explicit thread wrapper when you want to do something else.
  1186. This function uses validationData explicitly as the validation set and proceeds with training.
  1187. * @param trainData DetectionTrainData to train. See DetectionTrainData for details.
  1188. * @param validationData DetectionTrainData used explicitly as the validation set.
  1189. * @return Returns Status.
  1190. */
  1191. Status __stdcall StartTrain(const DetectionTrainData& trainData, const DetectionTrainData& validationData);
  1192. /**
  1193. * @brief Blocks until training is finished.
  1194. * @details Calling this function blocks the caller until training terminates.
  1195. * @return Returns Status.
  1196. */
  1197. Status __stdcall WaitTrain();
  1198. /**
  1199. * @brief Returns a Message representing the current training state.
  1200. * @details As training progresses, DetectionTrainer pushes messages into an internally managed message queue.
  1201. When the GetTrainMessage function is called, the next message in that queue is dequeued and returned to the user as a Message object.
  1202. If the internal message queue is empty, the call blocks until a message is enqueued; that message is then dequeued and returned immediately.
  1203. * @return Returns Message object.
  1204. */
  1205. Message __stdcall GetTrainMessage();
  1206. /**
  1207. * @brief Returns whether the internal Message queue is empty.
  1208. * @details DetectionTrainer has a message queue for communication. This function returns true when that queue is empty, otherwise false.
  1209. * @return true if the message queue is empty, false otherwise.
  1210. */
  1211. bool __stdcall MessageIsEmpty();
  1212. /**
  1213. * @brief Stops training if it is currently in progress. If training is not in progress, the call does nothing.
  1214. * @return Returns Status.
  1215. */
  1216. Status __stdcall StopTrain();
  1217. /**
  1218. * @brief Explicitly releases resources. This is useful when the API is used from non-native languages (such as C#).
  1219. * @return Returns Status.
  1220. */
  1221. Status __stdcall Destroy();
  1222. private:
  1223. DetectionTrainer(const DetectionTrainer& src); //=delete; copy constructor declared private so DetectionTrainer cannot be copied (pre-C++11 idiom)
  1224. DetectionTrainer& operator=(const DetectionTrainer& rhs); //=delete; copy assignment declared private so DetectionTrainer cannot be copy-assigned (pre-C++11 idiom)
  1225. DeviceDescriptor m_devDesc;
  1226. InputDataType m_dataType;
  1227. Status m_status;
  1228. void* internal; // opaque handle to internal state (presumably a pimpl — TODO confirm)
  1229. };
  1230. }
  1231. }