Files
SVMClassifier/svm__classifier_8hpp_source.html
2025-06-22 11:25:27 +00:00

239 lines
41 KiB
HTML

<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "https://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
<html xmlns="http://www.w3.org/1999/xhtml" lang="en-US">
<head>
<!-- Encoding pragma comes first so the charset is declared before any text.
     The media type is text/html: "text/xhtml" is not a registered type, and the
     charset parameter is the part user agents act on. -->
<meta http-equiv="Content-Type" content="text/html;charset=UTF-8"/>
<meta http-equiv="X-UA-Compatible" content="IE=11"/>
<meta name="generator" content="Doxygen 1.9.8"/>
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<title>SVM Classifier C++: include/svm_classifier/svm_classifier.hpp Source File</title>
<!-- Stylesheets and scripts keep their original order: later stylesheets win the
     cascade, and the inline scripts in the body call names that these files are
     expected to define (SearchBox, init_search, init_codefold). -->
<link href="tabs.css" rel="stylesheet" type="text/css"/>
<script type="text/javascript" src="jquery.js"></script>
<script type="text/javascript" src="dynsections.js"></script>
<link href="search/search.css" rel="stylesheet" type="text/css"/>
<script type="text/javascript" src="search/searchdata.js"></script>
<script type="text/javascript" src="search/search.js"></script>
<link href="doxygen.css" rel="stylesheet" type="text/css" />
</head>
<body>
<div id="top"><!-- do not remove this div, it is closed by doxygen! -->
<!-- Project banner: name, version number and one-line brief -->
<div id="titlearea">
  <table cellpadding="0" cellspacing="0">
    <tbody>
      <tr id="projectrow">
        <td id="projectalign">
          <div id="projectname">SVM Classifier C++<span id="projectnumber">&#160;1.0.0</span>
          </div>
          <div id="projectbrief">High-performance Support Vector Machine classifier with scikit-learn compatible API</div>
        </td>
      </tr>
    </tbody>
  </table>
</div>
<!-- end header part -->
<!-- Generated by Doxygen 1.9.8 -->
<script type="text/javascript">
/* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&amp;dn=expat.txt MIT */
/* Page-wide search controller. It must stay a global named searchBox: the inline
   onmouseover/onmouseout/onkeydown handlers of #MSearchSelectWindow call it by that name.
   SearchBox itself is not defined in this page; presumably it comes from
   search/search.js (confirm). The meaning of the three arguments is not visible here. */
var searchBox = new SearchBox("searchBox", "search/",'.html');
/* @license-end */
</script>
<script type="text/javascript" src="menudata.js"></script>
<script type="text/javascript" src="menu.js"></script>
<script type="text/javascript">
/* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&amp;dn=expat.txt MIT */
/* Runs once the DOM is ready: builds the navigation menu, then registers a second
   ready callback that initialises the search UI. The nesting is kept as generated;
   calling init_search() directly could change when it runs relative to other ready
   handlers on the page. initMenu and init_search are defined outside this page
   (presumably menu.js and search/search.js); their argument meanings are not
   visible here. */
$(function() {
initMenu('',true,false,'search.php','Search');
$(document).ready(function() { init_search(); });
});
/* @license-end */
</script>
<div id="main-nav"></div>
<script type="text/javascript">
/* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&amp;dn=expat.txt MIT */
/* Calls init_codefold once the DOM is ready. The function is defined outside this
   page; presumably it wires up the fold controls for the div.foldopen block in the
   listing. The meaning of the 0 argument is not visible here (TODO confirm). */
$(document).ready(function() { init_codefold(0); });
/* @license-end */
</script>
<!-- Search filter chooser: empty in the markup, with mouse and key events forwarded to the global searchBox -->
<div
  id="MSearchSelectWindow"
  onkeydown="return searchBox.OnSearchSelectKey(event)"
  onmouseout="return searchBox.OnSearchSelectHide()"
  onmouseover="return searchBox.OnSearchSelectShow()">
</div>
<!-- Search results container (a div, closed by default) with its three status messages -->
<div id="MSearchResultsWindow">
  <div id="MSearchResults">
    <div class="SRPage">
      <div id="SRIndex">
        <div id="SRResults"></div>
        <div id="Loading" class="SRStatus">Loading...</div>
        <div id="Searching" class="SRStatus">Searching...</div>
        <div id="NoMatches" class="SRStatus">No Matches</div>
      </div>
    </div>
  </div>
</div>
<!-- Breadcrumb trail: include / svm_classifier -->
<div class="navpath" id="nav-path">
  <ul>
    <li class="navelem"><a class="el" href="dir_d44c64559bbebec7f509842c48db8b23.html">include</a></li>
    <li class="navelem"><a class="el" href="dir_daf582bc00f2bbc6516ddb6630e28009.html">svm_classifier</a></li>
  </ul>
</div>
</div><!-- top -->
<!-- Page heading: name of the listed file -->
<div class="header">
  <div class="headertitle">
    <div class="title">svm_classifier.hpp</div>
  </div>
</div><!-- header -->
<div class="contents">
<div class="fragment"><div class="line"><a id="l00001" name="l00001"></a><span class="lineno"> 1</span><span class="preprocessor">#pragma once</span></div>
<div class="line"><a id="l00002" name="l00002"></a><span class="lineno"> 2</span> </div>
<div class="line"><a id="l00003" name="l00003"></a><span class="lineno"> 3</span><span class="preprocessor">#include &quot;types.hpp&quot;</span></div>
<div class="line"><a id="l00004" name="l00004"></a><span class="lineno"> 4</span><span class="preprocessor">#include &quot;kernel_parameters.hpp&quot;</span></div>
<div class="line"><a id="l00005" name="l00005"></a><span class="lineno"> 5</span><span class="preprocessor">#include &quot;data_converter.hpp&quot;</span></div>
<div class="line"><a id="l00006" name="l00006"></a><span class="lineno"> 6</span><span class="preprocessor">#include &quot;multiclass_strategy.hpp&quot;</span></div>
<div class="line"><a id="l00007" name="l00007"></a><span class="lineno"> 7</span><span class="preprocessor">#include &lt;torch/torch.h&gt;</span></div>
<div class="line"><a id="l00008" name="l00008"></a><span class="lineno"> 8</span><span class="preprocessor">#include &lt;nlohmann/json.hpp&gt;</span></div>
<div class="line"><a id="l00009" name="l00009"></a><span class="lineno"> 9</span><span class="preprocessor">#include &lt;memory&gt;</span></div>
<div class="line"><a id="l00010" name="l00010"></a><span class="lineno"> 10</span><span class="preprocessor">#include &lt;string&gt;</span></div>
<div class="line"><a id="l00011" name="l00011"></a><span class="lineno"> 11</span> </div>
<div class="line"><a id="l00012" name="l00012"></a><span class="lineno"> 12</span><span class="keyword">namespace </span>svm_classifier {</div>
<div class="line"><a id="l00013" name="l00013"></a><span class="lineno"> 13</span> </div>
<div class="foldopen" id="foldopen00021" data-start="{" data-end="};">
<!-- Listing rows for class SVMClassifier, source lines 21 to 68: class head,
     constructors, destructor, deleted copy operations and noexcept move operations.
     A linked line number points at that member's entry in
     classsvm__classifier_1_1SVMClassifier.html. Gaps in the numbering (22 to 26,
     27 to 32, and so on) are source lines left out of the listing, presumably
     documentation comments (confirm against the header).
     NOTE(review): operator= is hyperlinked on the copy-assignment row (57) but not
     on the move-assignment row (67). -->
<div class="line"><a id="l00021" name="l00021"></a><span class="lineno"><a class="line" href="classsvm__classifier_1_1SVMClassifier.html"> 21</a></span> <span class="keyword">class </span><a class="code hl_class" href="classsvm__classifier_1_1SVMClassifier.html">SVMClassifier</a> {</div>
<div class="line"><a id="l00022" name="l00022"></a><span class="lineno"> 22</span> <span class="keyword">public</span>:</div>
<div class="line"><a id="l00026" name="l00026"></a><span class="lineno"><a class="line" href="classsvm__classifier_1_1SVMClassifier.html#a3ed45cdbc3fc5d947320177f42115dcf"> 26</a></span> <a class="code hl_function" href="classsvm__classifier_1_1SVMClassifier.html#a3ed45cdbc3fc5d947320177f42115dcf">SVMClassifier</a>();</div>
<div class="line"><a id="l00027" name="l00027"></a><span class="lineno"> 27</span> </div>
<div class="line"><a id="l00032" name="l00032"></a><span class="lineno"><a class="line" href="classsvm__classifier_1_1SVMClassifier.html#a2afb41f77de4e8de6368d274a30191ec"> 32</a></span> <span class="keyword">explicit</span> <a class="code hl_function" href="classsvm__classifier_1_1SVMClassifier.html#a2afb41f77de4e8de6368d274a30191ec">SVMClassifier</a>(<span class="keyword">const</span> nlohmann::json&amp; config);</div>
<div class="line"><a id="l00033" name="l00033"></a><span class="lineno"> 33</span> </div>
<div class="line"><a id="l00040" name="l00040"></a><span class="lineno"><a class="line" href="classsvm__classifier_1_1SVMClassifier.html#a90b2f18dd2cfeb23cf1375f265e22db0"> 40</a></span> <a class="code hl_function" href="classsvm__classifier_1_1SVMClassifier.html#a90b2f18dd2cfeb23cf1375f265e22db0">SVMClassifier</a>(KernelType kernel,</div>
<div class="line"><a id="l00041" name="l00041"></a><span class="lineno"> 41</span> <span class="keywordtype">double</span> C = 1.0,</div>
<div class="line"><a id="l00042" name="l00042"></a><span class="lineno"> 42</span> MulticlassStrategy multiclass_strategy = MulticlassStrategy::ONE_VS_REST);</div>
<div class="line"><a id="l00043" name="l00043"></a><span class="lineno"> 43</span> </div>
<div class="line"><a id="l00047" name="l00047"></a><span class="lineno"><a class="line" href="classsvm__classifier_1_1SVMClassifier.html#a233584f6696969ce1a402624fd046146"> 47</a></span> <a class="code hl_function" href="classsvm__classifier_1_1SVMClassifier.html#a233584f6696969ce1a402624fd046146">~SVMClassifier</a>();</div>
<div class="line"><a id="l00048" name="l00048"></a><span class="lineno"> 48</span> </div>
<div class="line"><a id="l00052" name="l00052"></a><span class="lineno"><a class="line" href="classsvm__classifier_1_1SVMClassifier.html#a377b6082ac4153be3197ef70c1c82984"> 52</a></span> <a class="code hl_function" href="classsvm__classifier_1_1SVMClassifier.html#a377b6082ac4153be3197ef70c1c82984">SVMClassifier</a>(<span class="keyword">const</span> <a class="code hl_class" href="classsvm__classifier_1_1SVMClassifier.html">SVMClassifier</a>&amp;) = <span class="keyword">delete</span>;</div>
<div class="line"><a id="l00053" name="l00053"></a><span class="lineno"> 53</span> </div>
<div class="line"><a id="l00057" name="l00057"></a><span class="lineno"><a class="line" href="classsvm__classifier_1_1SVMClassifier.html#a209902805c75e8f22c55575adfedc7be"> 57</a></span> <a class="code hl_class" href="classsvm__classifier_1_1SVMClassifier.html">SVMClassifier</a>&amp; <a class="code hl_function" href="classsvm__classifier_1_1SVMClassifier.html#a209902805c75e8f22c55575adfedc7be">operator=</a>(<span class="keyword">const</span> <a class="code hl_class" href="classsvm__classifier_1_1SVMClassifier.html">SVMClassifier</a>&amp;) = <span class="keyword">delete</span>;</div>
<div class="line"><a id="l00058" name="l00058"></a><span class="lineno"> 58</span> </div>
<div class="line"><a id="l00062" name="l00062"></a><span class="lineno"><a class="line" href="classsvm__classifier_1_1SVMClassifier.html#ae2eafdc66d1907c145efffd186dfff3f"> 62</a></span> <a class="code hl_function" href="classsvm__classifier_1_1SVMClassifier.html#ae2eafdc66d1907c145efffd186dfff3f">SVMClassifier</a>(<a class="code hl_class" href="classsvm__classifier_1_1SVMClassifier.html">SVMClassifier</a>&amp;&amp;) noexcept;</div>
<div class="line"><a id="l00063" name="l00063"></a><span class="lineno"> 63</span> </div>
<div class="line"><a id="l00067" name="l00067"></a><span class="lineno"><a class="line" href="classsvm__classifier_1_1SVMClassifier.html#a49b6a4a5ae8a8e0eaf24221482be3d6a"> 67</a></span> <a class="code hl_class" href="classsvm__classifier_1_1SVMClassifier.html">SVMClassifier</a>&amp; operator=(<a class="code hl_class" href="classsvm__classifier_1_1SVMClassifier.html">SVMClassifier</a>&amp;&amp;) noexcept;</div>
<div class="line"><a id="l00068" name="l00068"></a><span class="lineno"> 68</span> </div>
<!-- Listing rows for source lines 77 to 131: fit, predict, predict_proba,
     decision_function, score, evaluate, set_parameters and get_parameters.
     Every const on these rows is wrapped in span.keyword, the same markup the
     other rows of this listing use for const, so the keyword is highlighted
     uniformly. The visible text of each row is unchanged. -->
<div class="line"><a id="l00077" name="l00077"></a><span class="lineno"><a class="line" href="classsvm__classifier_1_1SVMClassifier.html#a7e6648c4d2bac92bb00381076ea92db3"> 77</a></span> <a class="code hl_struct" href="structsvm__classifier_1_1TrainingMetrics.html">TrainingMetrics</a> <a class="code hl_function" href="classsvm__classifier_1_1SVMClassifier.html#a7e6648c4d2bac92bb00381076ea92db3">fit</a>(<span class="keyword">const</span> torch::Tensor&amp; X, <span class="keyword">const</span> torch::Tensor&amp; y);</div>
<div class="line"><a id="l00078" name="l00078"></a><span class="lineno"> 78</span> </div>
<div class="line"><a id="l00085" name="l00085"></a><span class="lineno"><a class="line" href="classsvm__classifier_1_1SVMClassifier.html#a5c998d5574b3b6afe003b23ed02ed1d1"> 85</a></span> torch::Tensor <a class="code hl_function" href="classsvm__classifier_1_1SVMClassifier.html#a5c998d5574b3b6afe003b23ed02ed1d1">predict</a>(<span class="keyword">const</span> torch::Tensor&amp; X);</div>
<div class="line"><a id="l00086" name="l00086"></a><span class="lineno"> 86</span> </div>
<div class="line"><a id="l00093" name="l00093"></a><span class="lineno"><a class="line" href="classsvm__classifier_1_1SVMClassifier.html#ab4ef3c839e085ece646cdd2501a51f67"> 93</a></span> torch::Tensor <a class="code hl_function" href="classsvm__classifier_1_1SVMClassifier.html#ab4ef3c839e085ece646cdd2501a51f67">predict_proba</a>(<span class="keyword">const</span> torch::Tensor&amp; X);</div>
<div class="line"><a id="l00094" name="l00094"></a><span class="lineno"> 94</span> </div>
<div class="line"><a id="l00101" name="l00101"></a><span class="lineno"><a class="line" href="classsvm__classifier_1_1SVMClassifier.html#ad153c0537998eae5fbca5fd0b5ead2b7"> 101</a></span> torch::Tensor <a class="code hl_function" href="classsvm__classifier_1_1SVMClassifier.html#ad153c0537998eae5fbca5fd0b5ead2b7">decision_function</a>(<span class="keyword">const</span> torch::Tensor&amp; X);</div>
<div class="line"><a id="l00102" name="l00102"></a><span class="lineno"> 102</span> </div>
<div class="line"><a id="l00110" name="l00110"></a><span class="lineno"><a class="line" href="classsvm__classifier_1_1SVMClassifier.html#a0479c57489c14be4a5ca79368086f7f6"> 110</a></span> <span class="keywordtype">double</span> <a class="code hl_function" href="classsvm__classifier_1_1SVMClassifier.html#a0479c57489c14be4a5ca79368086f7f6">score</a>(<span class="keyword">const</span> torch::Tensor&amp; X, <span class="keyword">const</span> torch::Tensor&amp; y_true);</div>
<div class="line"><a id="l00111" name="l00111"></a><span class="lineno"> 111</span> </div>
<div class="line"><a id="l00118" name="l00118"></a><span class="lineno"><a class="line" href="classsvm__classifier_1_1SVMClassifier.html#a38a9b020b9f4f9254920c97a3a047e9b"> 118</a></span> <a class="code hl_struct" href="structsvm__classifier_1_1EvaluationMetrics.html">EvaluationMetrics</a> <a class="code hl_function" href="classsvm__classifier_1_1SVMClassifier.html#a38a9b020b9f4f9254920c97a3a047e9b">evaluate</a>(<span class="keyword">const</span> torch::Tensor&amp; X, <span class="keyword">const</span> torch::Tensor&amp; y_true);</div>
<div class="line"><a id="l00119" name="l00119"></a><span class="lineno"> 119</span> </div>
<div class="line"><a id="l00125" name="l00125"></a><span class="lineno"><a class="line" href="classsvm__classifier_1_1SVMClassifier.html#adb01e761fea07c709f3a0e315d3d0e06"> 125</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classsvm__classifier_1_1SVMClassifier.html#adb01e761fea07c709f3a0e315d3d0e06">set_parameters</a>(<span class="keyword">const</span> nlohmann::json&amp; config);</div>
<div class="line"><a id="l00126" name="l00126"></a><span class="lineno"> 126</span> </div>
<div class="line"><a id="l00131" name="l00131"></a><span class="lineno"><a class="line" href="classsvm__classifier_1_1SVMClassifier.html#a7c39ec09b15186dcb4f04ae7171d23bb"> 131</a></span> nlohmann::json <a class="code hl_function" href="classsvm__classifier_1_1SVMClassifier.html#a7c39ec09b15186dcb4f04ae7171d23bb">get_parameters</a>() <span class="keyword">const</span>;</div>
<!-- Listing rows for source lines 132 to 236: state accessors, model save/load,
     cross_validate, grid_search, get_feature_importance and reset. is_fitted,
     get_n_features, get_training_metrics, get_kernel_type,
     get_multiclass_strategy and get_svm_library are defined inline in the header,
     so their bodies appear on the row itself. -->
<div class="line"><a id="l00132" name="l00132"></a><span class="lineno"> 132</span> </div>
<div class="line"><a id="l00137" name="l00137"></a><span class="lineno"><a class="line" href="classsvm__classifier_1_1SVMClassifier.html#a71a85ab7893e7e2b40763db34096d8bb"> 137</a></span> <span class="keywordtype">bool</span> <a class="code hl_function" href="classsvm__classifier_1_1SVMClassifier.html#a71a85ab7893e7e2b40763db34096d8bb">is_fitted</a>()<span class="keyword"> const </span>{ <span class="keywordflow">return</span> is_fitted_; }</div>
<div class="line"><a id="l00138" name="l00138"></a><span class="lineno"> 138</span> </div>
<div class="line"><a id="l00143" name="l00143"></a><span class="lineno"><a class="line" href="classsvm__classifier_1_1SVMClassifier.html#a75d501339e2e2273082b0838e9caadcd"> 143</a></span> <span class="keywordtype">int</span> <a class="code hl_function" href="classsvm__classifier_1_1SVMClassifier.html#a75d501339e2e2273082b0838e9caadcd">get_n_classes</a>() <span class="keyword">const</span>;</div>
<div class="line"><a id="l00144" name="l00144"></a><span class="lineno"> 144</span> </div>
<div class="line"><a id="l00149" name="l00149"></a><span class="lineno"><a class="line" href="classsvm__classifier_1_1SVMClassifier.html#af0fea42cdfc9416ed854b0d4aefa82b9"> 149</a></span> std::vector&lt;int&gt; <a class="code hl_function" href="classsvm__classifier_1_1SVMClassifier.html#af0fea42cdfc9416ed854b0d4aefa82b9">get_classes</a>() <span class="keyword">const</span>;</div>
<div class="line"><a id="l00150" name="l00150"></a><span class="lineno"> 150</span> </div>
<div class="line"><a id="l00155" name="l00155"></a><span class="lineno"><a class="line" href="classsvm__classifier_1_1SVMClassifier.html#a780afcb2ad618e46541aff8a44e9c7b4"> 155</a></span> <span class="keywordtype">int</span> <a class="code hl_function" href="classsvm__classifier_1_1SVMClassifier.html#a780afcb2ad618e46541aff8a44e9c7b4">get_n_features</a>()<span class="keyword"> const </span>{ <span class="keywordflow">return</span> n_features_; }</div>
<div class="line"><a id="l00156" name="l00156"></a><span class="lineno"> 156</span> </div>
<div class="line"><a id="l00161" name="l00161"></a><span class="lineno"><a class="line" href="classsvm__classifier_1_1SVMClassifier.html#a0b8c77f81d84489b2da0d080773a2970"> 161</a></span> <a class="code hl_struct" href="structsvm__classifier_1_1TrainingMetrics.html">TrainingMetrics</a> <a class="code hl_function" href="classsvm__classifier_1_1SVMClassifier.html#a0b8c77f81d84489b2da0d080773a2970">get_training_metrics</a>()<span class="keyword"> const </span>{ <span class="keywordflow">return</span> training_metrics_; }</div>
<div class="line"><a id="l00162" name="l00162"></a><span class="lineno"> 162</span> </div>
<div class="line"><a id="l00167" name="l00167"></a><span class="lineno"><a class="line" href="classsvm__classifier_1_1SVMClassifier.html#a3f8b4e932f075b267507ad77a499a135"> 167</a></span> <span class="keywordtype">bool</span> <a class="code hl_function" href="classsvm__classifier_1_1SVMClassifier.html#a3f8b4e932f075b267507ad77a499a135">supports_probability</a>() <span class="keyword">const</span>;</div>
<div class="line"><a id="l00168" name="l00168"></a><span class="lineno"> 168</span> </div>
<div class="line"><a id="l00174" name="l00174"></a><span class="lineno"><a class="line" href="classsvm__classifier_1_1SVMClassifier.html#ab8a0bd35705825e80a7567b576d47359"> 174</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classsvm__classifier_1_1SVMClassifier.html#ab8a0bd35705825e80a7567b576d47359">save_model</a>(<span class="keyword">const</span> std::string&amp; filename) <span class="keyword">const</span>;</div>
<div class="line"><a id="l00175" name="l00175"></a><span class="lineno"> 175</span> </div>
<div class="line"><a id="l00181" name="l00181"></a><span class="lineno"><a class="line" href="classsvm__classifier_1_1SVMClassifier.html#a583f5743acf5e6b850e079b9190989f1"> 181</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classsvm__classifier_1_1SVMClassifier.html#a583f5743acf5e6b850e079b9190989f1">load_model</a>(<span class="keyword">const</span> std::string&amp; filename);</div>
<div class="line"><a id="l00182" name="l00182"></a><span class="lineno"> 182</span> </div>
<div class="line"><a id="l00187" name="l00187"></a><span class="lineno"><a class="line" href="classsvm__classifier_1_1SVMClassifier.html#a55338ab396bd5da923b6acbef8ed783a"> 187</a></span> KernelType <a class="code hl_function" href="classsvm__classifier_1_1SVMClassifier.html#a55338ab396bd5da923b6acbef8ed783a">get_kernel_type</a>()<span class="keyword"> const </span>{ <span class="keywordflow">return</span> params_.get_kernel_type(); }</div>
<div class="line"><a id="l00188" name="l00188"></a><span class="lineno"> 188</span> </div>
<div class="line"><a id="l00193" name="l00193"></a><span class="lineno"><a class="line" href="classsvm__classifier_1_1SVMClassifier.html#a14c2f7917c8a91154c09160288509f2c"> 193</a></span> MulticlassStrategy <a class="code hl_function" href="classsvm__classifier_1_1SVMClassifier.html#a14c2f7917c8a91154c09160288509f2c">get_multiclass_strategy</a>()<span class="keyword"> const </span>{ <span class="keywordflow">return</span> params_.get_multiclass_strategy(); }</div>
<div class="line"><a id="l00194" name="l00194"></a><span class="lineno"> 194</span> </div>
<!-- NOTE(review): on the next row the inner get_svm_library(...) call is hyperlinked
     to the same zero-argument member that contains it, although the call passes an
     argument. It presumably targets a different function taking a KernelType;
     confirm the intended link target (and that the header compiles as written). -->
<div class="line"><a id="l00199" name="l00199"></a><span class="lineno"><a class="line" href="classsvm__classifier_1_1SVMClassifier.html#a38173e5cf0f6a4620f032fd54c28d592"> 199</a></span> SVMLibrary <a class="code hl_function" href="classsvm__classifier_1_1SVMClassifier.html#a38173e5cf0f6a4620f032fd54c28d592">get_svm_library</a>()<span class="keyword"> const </span>{ <span class="keywordflow">return</span> <a class="code hl_function" href="classsvm__classifier_1_1SVMClassifier.html#a38173e5cf0f6a4620f032fd54c28d592">get_svm_library</a>(params_.get_kernel_type()); }</div>
<div class="line"><a id="l00200" name="l00200"></a><span class="lineno"> 200</span> </div>
<div class="line"><a id="l00208" name="l00208"></a><span class="lineno"><a class="line" href="classsvm__classifier_1_1SVMClassifier.html#a4c91072ea0d3d9b97ba458ff7d0898b8"> 208</a></span> std::vector&lt;double&gt; <a class="code hl_function" href="classsvm__classifier_1_1SVMClassifier.html#a4c91072ea0d3d9b97ba458ff7d0898b8">cross_validate</a>(<span class="keyword">const</span> torch::Tensor&amp; X,</div>
<div class="line"><a id="l00209" name="l00209"></a><span class="lineno"> 209</span> <span class="keyword">const</span> torch::Tensor&amp; y,</div>
<div class="line"><a id="l00210" name="l00210"></a><span class="lineno"> 210</span> <span class="keywordtype">int</span> cv = 5);</div>
<div class="line"><a id="l00211" name="l00211"></a><span class="lineno"> 211</span> </div>
<div class="line"><a id="l00220" name="l00220"></a><span class="lineno"><a class="line" href="classsvm__classifier_1_1SVMClassifier.html#afed66a704dfb38cc7d080d3337d10194"> 220</a></span> nlohmann::json <a class="code hl_function" href="classsvm__classifier_1_1SVMClassifier.html#afed66a704dfb38cc7d080d3337d10194">grid_search</a>(<span class="keyword">const</span> torch::Tensor&amp; X,</div>
<div class="line"><a id="l00221" name="l00221"></a><span class="lineno"> 221</span> <span class="keyword">const</span> torch::Tensor&amp; y,</div>
<div class="line"><a id="l00222" name="l00222"></a><span class="lineno"> 222</span> <span class="keyword">const</span> nlohmann::json&amp; param_grid,</div>
<div class="line"><a id="l00223" name="l00223"></a><span class="lineno"> 223</span> <span class="keywordtype">int</span> cv = 5);</div>
<div class="line"><a id="l00224" name="l00224"></a><span class="lineno"> 224</span> </div>
<div class="line"><a id="l00230" name="l00230"></a><span class="lineno"><a class="line" href="classsvm__classifier_1_1SVMClassifier.html#a2ade33562381e34cbe4b04089545a715"> 230</a></span> torch::Tensor <a class="code hl_function" href="classsvm__classifier_1_1SVMClassifier.html#a2ade33562381e34cbe4b04089545a715">get_feature_importance</a>() <span class="keyword">const</span>;</div>
<div class="line"><a id="l00231" name="l00231"></a><span class="lineno"> 231</span> </div>
<div class="line"><a id="l00235" name="l00235"></a><span class="lineno"><a class="line" href="classsvm__classifier_1_1SVMClassifier.html#aa2bd5715c9e54e3fb465a9bcbf2e9c8a"> 235</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classsvm__classifier_1_1SVMClassifier.html#aa2bd5715c9e54e3fb465a9bcbf2e9c8a">reset</a>();</div>
<div class="line"><a id="l00236" name="l00236"></a><span class="lineno"> 236</span> </div>
<!-- Listing rows for source lines 237 to 295: the private section, with data
     members first and helper declarations after, then the closing brace of the
     class. Line numbers on these rows are plain text with no member link,
     presumably because private members have no entry on the generated class page
     (confirm against the Doxygen configuration). -->
<div class="line"><a id="l00237" name="l00237"></a><span class="lineno"> 237</span> <span class="keyword">private</span>:</div>
<div class="line"><a id="l00238" name="l00238"></a><span class="lineno"> 238</span> KernelParameters params_; </div>
<div class="line"><a id="l00239" name="l00239"></a><span class="lineno"> 239</span> std::unique_ptr&lt;MulticlassStrategyBase&gt; multiclass_strategy_; </div>
<div class="line"><a id="l00240" name="l00240"></a><span class="lineno"> 240</span> std::unique_ptr&lt;DataConverter&gt; data_converter_; </div>
<div class="line"><a id="l00241" name="l00241"></a><span class="lineno"> 241</span> </div>
<div class="line"><a id="l00242" name="l00242"></a><span class="lineno"> 242</span> <span class="keywordtype">bool</span> is_fitted_; </div>
<div class="line"><a id="l00243" name="l00243"></a><span class="lineno"> 243</span> <span class="keywordtype">int</span> n_features_; </div>
<div class="line"><a id="l00244" name="l00244"></a><span class="lineno"> 244</span> <a class="code hl_struct" href="structsvm__classifier_1_1TrainingMetrics.html">TrainingMetrics</a> training_metrics_; </div>
<div class="line"><a id="l00245" name="l00245"></a><span class="lineno"> 245</span> </div>
<div class="line"><a id="l00252" name="l00252"></a><span class="lineno"> 252</span> <span class="keywordtype">void</span> validate_input(<span class="keyword">const</span> torch::Tensor&amp; X,</div>
<div class="line"><a id="l00253" name="l00253"></a><span class="lineno"> 253</span> <span class="keyword">const</span> torch::Tensor&amp; y = torch::Tensor(),</div>
<div class="line"><a id="l00254" name="l00254"></a><span class="lineno"> 254</span> <span class="keywordtype">bool</span> check_fitted = <span class="keyword">false</span>);</div>
<div class="line"><a id="l00255" name="l00255"></a><span class="lineno"> 255</span> </div>
<div class="line"><a id="l00259" name="l00259"></a><span class="lineno"> 259</span> <span class="keywordtype">void</span> initialize_multiclass_strategy();</div>
<div class="line"><a id="l00260" name="l00260"></a><span class="lineno"> 260</span> </div>
<div class="line"><a id="l00267" name="l00267"></a><span class="lineno"> 267</span> std::vector&lt;std::vector&lt;int&gt;&gt; calculate_confusion_matrix(<span class="keyword">const</span> std::vector&lt;int&gt;&amp; y_true,</div>
<div class="line"><a id="l00268" name="l00268"></a><span class="lineno"> 268</span> <span class="keyword">const</span> std::vector&lt;int&gt;&amp; y_pred);</div>
<div class="line"><a id="l00269" name="l00269"></a><span class="lineno"> 269</span> </div>
<div class="line"><a id="l00275" name="l00275"></a><span class="lineno"> 275</span> std::tuple&lt;double, double, double&gt; calculate_metrics_from_confusion_matrix(</div>
<div class="line"><a id="l00276" name="l00276"></a><span class="lineno"> 276</span> <span class="keyword">const</span> std::vector&lt;std::vector&lt;int&gt;&gt;&amp; confusion_matrix);</div>
<div class="line"><a id="l00277" name="l00277"></a><span class="lineno"> 277</span> </div>
<div class="line"><a id="l00286" name="l00286"></a><span class="lineno"> 286</span> std::tuple&lt;torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor&gt;</div>
<div class="line"><a id="l00287" name="l00287"></a><span class="lineno"> 287</span> split_for_cv(<span class="keyword">const</span> torch::Tensor&amp; X, <span class="keyword">const</span> torch::Tensor&amp; y, <span class="keywordtype">int</span> fold, <span class="keywordtype">int</span> n_folds);</div>
<div class="line"><a id="l00288" name="l00288"></a><span class="lineno"> 288</span> </div>
<div class="line"><a id="l00294" name="l00294"></a><span class="lineno"> 294</span> std::vector&lt;nlohmann::json&gt; generate_param_combinations(<span class="keyword">const</span> nlohmann::json&amp; param_grid);</div>
<div class="line"><a id="l00295" name="l00295"></a><span class="lineno"> 295</span> };</div>
</div>
<div class="line"><a id="l00296" name="l00296"></a><span class="lineno"> 296</span> </div>
<div class="line"><a id="l00297" name="l00297"></a><span class="lineno"> 297</span>} <span class="comment">// namespace svm_classifier</span></div>
<div class="ttc" id="aclasssvm__classifier_1_1SVMClassifier_html"><div class="ttname"><a href="classsvm__classifier_1_1SVMClassifier.html">svm_classifier::SVMClassifier</a></div><div class="ttdoc">Support Vector Machine Classifier with scikit-learn compatible API.</div><div class="ttdef"><b>Definition</b> <a href="svm__classifier_8hpp_source.html#l00021">svm_classifier.hpp:21</a></div></div>
<div class="ttc" id="aclasssvm__classifier_1_1SVMClassifier_html_a0479c57489c14be4a5ca79368086f7f6"><div class="ttname"><a href="classsvm__classifier_1_1SVMClassifier.html#a0479c57489c14be4a5ca79368086f7f6">svm_classifier::SVMClassifier::score</a></div><div class="ttdeci">double score(const torch::Tensor &amp;X, const torch::Tensor &amp;y_true)</div><div class="ttdoc">Calculate accuracy score on test data.</div></div>
<div class="ttc" id="aclasssvm__classifier_1_1SVMClassifier_html_a0b8c77f81d84489b2da0d080773a2970"><div class="ttname"><a href="classsvm__classifier_1_1SVMClassifier.html#a0b8c77f81d84489b2da0d080773a2970">svm_classifier::SVMClassifier::get_training_metrics</a></div><div class="ttdeci">TrainingMetrics get_training_metrics() const</div><div class="ttdoc">Get training metrics from last fit.</div><div class="ttdef"><b>Definition</b> <a href="svm__classifier_8hpp_source.html#l00161">svm_classifier.hpp:161</a></div></div>
<div class="ttc" id="aclasssvm__classifier_1_1SVMClassifier_html_a14c2f7917c8a91154c09160288509f2c"><div class="ttname"><a href="classsvm__classifier_1_1SVMClassifier.html#a14c2f7917c8a91154c09160288509f2c">svm_classifier::SVMClassifier::get_multiclass_strategy</a></div><div class="ttdeci">MulticlassStrategy get_multiclass_strategy() const</div><div class="ttdoc">Get multiclass strategy.</div><div class="ttdef"><b>Definition</b> <a href="svm__classifier_8hpp_source.html#l00193">svm_classifier.hpp:193</a></div></div>
<div class="ttc" id="aclasssvm__classifier_1_1SVMClassifier_html_a209902805c75e8f22c55575adfedc7be"><div class="ttname"><a href="classsvm__classifier_1_1SVMClassifier.html#a209902805c75e8f22c55575adfedc7be">svm_classifier::SVMClassifier::operator=</a></div><div class="ttdeci">SVMClassifier &amp; operator=(const SVMClassifier &amp;)=delete</div><div class="ttdoc">Copy assignment (deleted - models are not copyable)</div></div>
<div class="ttc" id="aclasssvm__classifier_1_1SVMClassifier_html_a233584f6696969ce1a402624fd046146"><div class="ttname"><a href="classsvm__classifier_1_1SVMClassifier.html#a233584f6696969ce1a402624fd046146">svm_classifier::SVMClassifier::~SVMClassifier</a></div><div class="ttdeci">~SVMClassifier()</div><div class="ttdoc">Destructor.</div></div>
<div class="ttc" id="aclasssvm__classifier_1_1SVMClassifier_html_a2ade33562381e34cbe4b04089545a715"><div class="ttname"><a href="classsvm__classifier_1_1SVMClassifier.html#a2ade33562381e34cbe4b04089545a715">svm_classifier::SVMClassifier::get_feature_importance</a></div><div class="ttdeci">torch::Tensor get_feature_importance() const</div><div class="ttdoc">Get feature importance (for linear kernels only)</div></div>
<div class="ttc" id="aclasssvm__classifier_1_1SVMClassifier_html_a2afb41f77de4e8de6368d274a30191ec"><div class="ttname"><a href="classsvm__classifier_1_1SVMClassifier.html#a2afb41f77de4e8de6368d274a30191ec">svm_classifier::SVMClassifier::SVMClassifier</a></div><div class="ttdeci">SVMClassifier(const nlohmann::json &amp;config)</div><div class="ttdoc">Constructor with JSON parameters.</div></div>
<div class="ttc" id="aclasssvm__classifier_1_1SVMClassifier_html_a377b6082ac4153be3197ef70c1c82984"><div class="ttname"><a href="classsvm__classifier_1_1SVMClassifier.html#a377b6082ac4153be3197ef70c1c82984">svm_classifier::SVMClassifier::SVMClassifier</a></div><div class="ttdeci">SVMClassifier(const SVMClassifier &amp;)=delete</div><div class="ttdoc">Copy constructor (deleted - models are not copyable)</div></div>
<div class="ttc" id="aclasssvm__classifier_1_1SVMClassifier_html_a38173e5cf0f6a4620f032fd54c28d592"><div class="ttname"><a href="classsvm__classifier_1_1SVMClassifier.html#a38173e5cf0f6a4620f032fd54c28d592">svm_classifier::SVMClassifier::get_svm_library</a></div><div class="ttdeci">SVMLibrary get_svm_library() const</div><div class="ttdoc">Get SVM library being used.</div><div class="ttdef"><b>Definition</b> <a href="svm__classifier_8hpp_source.html#l00199">svm_classifier.hpp:199</a></div></div>
<div class="ttc" id="aclasssvm__classifier_1_1SVMClassifier_html_a38a9b020b9f4f9254920c97a3a047e9b"><div class="ttname"><a href="classsvm__classifier_1_1SVMClassifier.html#a38a9b020b9f4f9254920c97a3a047e9b">svm_classifier::SVMClassifier::evaluate</a></div><div class="ttdeci">EvaluationMetrics evaluate(const torch::Tensor &amp;X, const torch::Tensor &amp;y_true)</div><div class="ttdoc">Calculate detailed evaluation metrics.</div></div>
<div class="ttc" id="aclasssvm__classifier_1_1SVMClassifier_html_a3ed45cdbc3fc5d947320177f42115dcf"><div class="ttname"><a href="classsvm__classifier_1_1SVMClassifier.html#a3ed45cdbc3fc5d947320177f42115dcf">svm_classifier::SVMClassifier::SVMClassifier</a></div><div class="ttdeci">SVMClassifier()</div><div class="ttdoc">Default constructor with default parameters.</div></div>
<div class="ttc" id="aclasssvm__classifier_1_1SVMClassifier_html_a3f8b4e932f075b267507ad77a499a135"><div class="ttname"><a href="classsvm__classifier_1_1SVMClassifier.html#a3f8b4e932f075b267507ad77a499a135">svm_classifier::SVMClassifier::supports_probability</a></div><div class="ttdeci">bool supports_probability() const</div><div class="ttdoc">Check if the current model supports probability prediction.</div></div>
<div class="ttc" id="aclasssvm__classifier_1_1SVMClassifier_html_a4c91072ea0d3d9b97ba458ff7d0898b8"><div class="ttname"><a href="classsvm__classifier_1_1SVMClassifier.html#a4c91072ea0d3d9b97ba458ff7d0898b8">svm_classifier::SVMClassifier::cross_validate</a></div><div class="ttdeci">std::vector&lt; double &gt; cross_validate(const torch::Tensor &amp;X, const torch::Tensor &amp;y, int cv=5)</div><div class="ttdoc">Perform cross-validation.</div></div>
<div class="ttc" id="aclasssvm__classifier_1_1SVMClassifier_html_a55338ab396bd5da923b6acbef8ed783a"><div class="ttname"><a href="classsvm__classifier_1_1SVMClassifier.html#a55338ab396bd5da923b6acbef8ed783a">svm_classifier::SVMClassifier::get_kernel_type</a></div><div class="ttdeci">KernelType get_kernel_type() const</div><div class="ttdoc">Get kernel type.</div><div class="ttdef"><b>Definition</b> <a href="svm__classifier_8hpp_source.html#l00187">svm_classifier.hpp:187</a></div></div>
<div class="ttc" id="aclasssvm__classifier_1_1SVMClassifier_html_a583f5743acf5e6b850e079b9190989f1"><div class="ttname"><a href="classsvm__classifier_1_1SVMClassifier.html#a583f5743acf5e6b850e079b9190989f1">svm_classifier::SVMClassifier::load_model</a></div><div class="ttdeci">void load_model(const std::string &amp;filename)</div><div class="ttdoc">Load model from file.</div></div>
<div class="ttc" id="aclasssvm__classifier_1_1SVMClassifier_html_a5c998d5574b3b6afe003b23ed02ed1d1"><div class="ttname"><a href="classsvm__classifier_1_1SVMClassifier.html#a5c998d5574b3b6afe003b23ed02ed1d1">svm_classifier::SVMClassifier::predict</a></div><div class="ttdeci">torch::Tensor predict(const torch::Tensor &amp;X)</div><div class="ttdoc">Predict class labels for samples.</div></div>
<div class="ttc" id="aclasssvm__classifier_1_1SVMClassifier_html_a71a85ab7893e7e2b40763db34096d8bb"><div class="ttname"><a href="classsvm__classifier_1_1SVMClassifier.html#a71a85ab7893e7e2b40763db34096d8bb">svm_classifier::SVMClassifier::is_fitted</a></div><div class="ttdeci">bool is_fitted() const</div><div class="ttdoc">Check if the model is fitted/trained.</div><div class="ttdef"><b>Definition</b> <a href="svm__classifier_8hpp_source.html#l00137">svm_classifier.hpp:137</a></div></div>
<div class="ttc" id="aclasssvm__classifier_1_1SVMClassifier_html_a75d501339e2e2273082b0838e9caadcd"><div class="ttname"><a href="classsvm__classifier_1_1SVMClassifier.html#a75d501339e2e2273082b0838e9caadcd">svm_classifier::SVMClassifier::get_n_classes</a></div><div class="ttdeci">int get_n_classes() const</div><div class="ttdoc">Get the number of classes.</div></div>
<div class="ttc" id="aclasssvm__classifier_1_1SVMClassifier_html_a780afcb2ad618e46541aff8a44e9c7b4"><div class="ttname"><a href="classsvm__classifier_1_1SVMClassifier.html#a780afcb2ad618e46541aff8a44e9c7b4">svm_classifier::SVMClassifier::get_n_features</a></div><div class="ttdeci">int get_n_features() const</div><div class="ttdoc">Get the number of features.</div><div class="ttdef"><b>Definition</b> <a href="svm__classifier_8hpp_source.html#l00155">svm_classifier.hpp:155</a></div></div>
<div class="ttc" id="aclasssvm__classifier_1_1SVMClassifier_html_a7c39ec09b15186dcb4f04ae7171d23bb"><div class="ttname"><a href="classsvm__classifier_1_1SVMClassifier.html#a7c39ec09b15186dcb4f04ae7171d23bb">svm_classifier::SVMClassifier::get_parameters</a></div><div class="ttdeci">nlohmann::json get_parameters() const</div><div class="ttdoc">Get current parameters as JSON.</div></div>
<div class="ttc" id="aclasssvm__classifier_1_1SVMClassifier_html_a7e6648c4d2bac92bb00381076ea92db3"><div class="ttname"><a href="classsvm__classifier_1_1SVMClassifier.html#a7e6648c4d2bac92bb00381076ea92db3">svm_classifier::SVMClassifier::fit</a></div><div class="ttdeci">TrainingMetrics fit(const torch::Tensor &amp;X, const torch::Tensor &amp;y)</div><div class="ttdoc">Train the SVM classifier.</div></div>
<div class="ttc" id="aclasssvm__classifier_1_1SVMClassifier_html_a90b2f18dd2cfeb23cf1375f265e22db0"><div class="ttname"><a href="classsvm__classifier_1_1SVMClassifier.html#a90b2f18dd2cfeb23cf1375f265e22db0">svm_classifier::SVMClassifier::SVMClassifier</a></div><div class="ttdeci">SVMClassifier(KernelType kernel, double C=1.0, MulticlassStrategy multiclass_strategy=MulticlassStrategy::ONE_VS_REST)</div><div class="ttdoc">Constructor with explicit parameters.</div></div>
<div class="ttc" id="aclasssvm__classifier_1_1SVMClassifier_html_aa2bd5715c9e54e3fb465a9bcbf2e9c8a"><div class="ttname"><a href="classsvm__classifier_1_1SVMClassifier.html#aa2bd5715c9e54e3fb465a9bcbf2e9c8a">svm_classifier::SVMClassifier::reset</a></div><div class="ttdeci">void reset()</div><div class="ttdoc">Reset the classifier (clears the trained model).</div></div>
<div class="ttc" id="aclasssvm__classifier_1_1SVMClassifier_html_ab4ef3c839e085ece646cdd2501a51f67"><div class="ttname"><a href="classsvm__classifier_1_1SVMClassifier.html#ab4ef3c839e085ece646cdd2501a51f67">svm_classifier::SVMClassifier::predict_proba</a></div><div class="ttdeci">torch::Tensor predict_proba(const torch::Tensor &amp;X)</div><div class="ttdoc">Predict class probabilities for samples.</div></div>
<div class="ttc" id="aclasssvm__classifier_1_1SVMClassifier_html_ab8a0bd35705825e80a7567b576d47359"><div class="ttname"><a href="classsvm__classifier_1_1SVMClassifier.html#ab8a0bd35705825e80a7567b576d47359">svm_classifier::SVMClassifier::save_model</a></div><div class="ttdeci">void save_model(const std::string &amp;filename) const</div><div class="ttdoc">Save model to file.</div></div>
<div class="ttc" id="aclasssvm__classifier_1_1SVMClassifier_html_ad153c0537998eae5fbca5fd0b5ead2b7"><div class="ttname"><a href="classsvm__classifier_1_1SVMClassifier.html#ad153c0537998eae5fbca5fd0b5ead2b7">svm_classifier::SVMClassifier::decision_function</a></div><div class="ttdeci">torch::Tensor decision_function(const torch::Tensor &amp;X)</div><div class="ttdoc">Get decision function values.</div></div>
<div class="ttc" id="aclasssvm__classifier_1_1SVMClassifier_html_adb01e761fea07c709f3a0e315d3d0e06"><div class="ttname"><a href="classsvm__classifier_1_1SVMClassifier.html#adb01e761fea07c709f3a0e315d3d0e06">svm_classifier::SVMClassifier::set_parameters</a></div><div class="ttdeci">void set_parameters(const nlohmann::json &amp;config)</div><div class="ttdoc">Set parameters from JSON configuration.</div></div>
<div class="ttc" id="aclasssvm__classifier_1_1SVMClassifier_html_ae2eafdc66d1907c145efffd186dfff3f"><div class="ttname"><a href="classsvm__classifier_1_1SVMClassifier.html#ae2eafdc66d1907c145efffd186dfff3f">svm_classifier::SVMClassifier::SVMClassifier</a></div><div class="ttdeci">SVMClassifier(SVMClassifier &amp;&amp;) noexcept</div><div class="ttdoc">Move constructor.</div></div>
<div class="ttc" id="aclasssvm__classifier_1_1SVMClassifier_html_af0fea42cdfc9416ed854b0d4aefa82b9"><div class="ttname"><a href="classsvm__classifier_1_1SVMClassifier.html#af0fea42cdfc9416ed854b0d4aefa82b9">svm_classifier::SVMClassifier::get_classes</a></div><div class="ttdeci">std::vector&lt; int &gt; get_classes() const</div><div class="ttdoc">Get unique class labels.</div></div>
<div class="ttc" id="aclasssvm__classifier_1_1SVMClassifier_html_afed66a704dfb38cc7d080d3337d10194"><div class="ttname"><a href="classsvm__classifier_1_1SVMClassifier.html#afed66a704dfb38cc7d080d3337d10194">svm_classifier::SVMClassifier::grid_search</a></div><div class="ttdeci">nlohmann::json grid_search(const torch::Tensor &amp;X, const torch::Tensor &amp;y, const nlohmann::json &amp;param_grid, int cv=5)</div><div class="ttdoc">Find optimal hyperparameters using grid search.</div></div>
<div class="ttc" id="astructsvm__classifier_1_1EvaluationMetrics_html"><div class="ttname"><a href="structsvm__classifier_1_1EvaluationMetrics.html">svm_classifier::EvaluationMetrics</a></div><div class="ttdoc">Model evaluation metrics.</div><div class="ttdef"><b>Definition</b> <a href="types_8hpp_source.html#l00070">types.hpp:70</a></div></div>
<div class="ttc" id="astructsvm__classifier_1_1TrainingMetrics_html"><div class="ttname"><a href="structsvm__classifier_1_1TrainingMetrics.html">svm_classifier::TrainingMetrics</a></div><div class="ttdoc">Training metrics structure.</div><div class="ttdef"><b>Definition</b> <a href="types_8hpp_source.html#l00059">types.hpp:59</a></div></div>
</div><!-- fragment --></div><!-- contents -->
<!-- start footer part -->
<hr class="footer"/><address class="footer"><small>
Generated on Sun Jun 22 2025 11:25:27 for SVM Classifier C++ by&#160;<a href="https://www.doxygen.org/index.html"><img class="footer" src="doxygen.svg" width="104" height="31" alt="doxygen"/></a> 1.9.8
</small></address>
</body>
</html>