Files
SVMClassifier/multiclass__strategy_8hpp_source.html
2025-06-22 11:25:27 +00:00

51 KiB

<html xmlns="http://www.w3.org/1999/xhtml" lang="en-US"> <head> <script type="text/javascript" src="jquery.js"></script> <script type="text/javascript" src="dynsections.js"></script> <script type="text/javascript" src="search/searchdata.js"></script> <script type="text/javascript" src="search/search.js"></script> </head>
SVM Classifier C++ 1.0.0
High-performance Support Vector Machine classifier with scikit-learn compatible API
<script type="text/javascript"> /* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&dn=expat.txt MIT */ var searchBox = new SearchBox("searchBox", "search/",'.html'); /* @license-end */ </script> <script type="text/javascript" src="menudata.js"></script> <script type="text/javascript" src="menu.js"></script> <script type="text/javascript"> /* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&dn=expat.txt MIT */ $(function() { initMenu('',true,false,'search.php','Search'); $(document).ready(function() { init_search(); }); }); /* @license-end */ </script>
<script type="text/javascript"> /* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&dn=expat.txt MIT */ $(document).ready(function() { init_codefold(0); }); /* @license-end */ </script>
Loading...
Searching...
No Matches
multiclass_strategy.hpp
1#pragma once
2
3#include "types.hpp"
4#include "kernel_parameters.hpp"
5#include "data_converter.hpp"
6#include <torch/torch.h>
7#include <vector>
8#include <memory>
9#include <unordered_map>
10
11// Forward declarations
12struct svm_model;
13struct model;
14
15namespace svm_classifier {
16
21 public:
25 virtual ~MulticlassStrategyBase() = default;
26
35 virtual TrainingMetrics fit(const torch::Tensor& X,
36 const torch::Tensor& y,
37 const KernelParameters& params,
38 DataConverter& converter) = 0;
39
46 virtual std::vector<int> predict(const torch::Tensor& X,
47 DataConverter& converter) = 0;
48
55 virtual std::vector<std::vector<double>> predict_proba(const torch::Tensor& X,
56 DataConverter& converter) = 0;
57
64 virtual std::vector<std::vector<double>> decision_function(const torch::Tensor& X,
65 DataConverter& converter) = 0;
66
71 virtual std::vector<int> get_classes() const = 0;
72
77 virtual bool supports_probability() const = 0;
78
83 virtual int get_n_classes() const = 0;
84
89 virtual MulticlassStrategy get_strategy_type() const = 0;
90
91 protected:
92 std::vector<int> classes_;
93 bool is_trained_ = false;
94 };
95
100 public:
105
110
111 TrainingMetrics fit(const torch::Tensor& X,
112 const torch::Tensor& y,
113 const KernelParameters& params,
114 DataConverter& converter) override;
115
116 std::vector<int> predict(const torch::Tensor& X,
117 DataConverter& converter) override;
118
119 std::vector<std::vector<double>> predict_proba(const torch::Tensor& X,
120 DataConverter& converter) override;
121
122 std::vector<std::vector<double>> decision_function(const torch::Tensor& X,
123 DataConverter& converter) override;
124
125 std::vector<int> get_classes() const override { return classes_; }
126
127 bool supports_probability() const override;
128
129 int get_n_classes() const override { return static_cast<int>(classes_.size()); }
130
131 MulticlassStrategy get_strategy_type() const override { return MulticlassStrategy::ONE_VS_REST; }
132
133 private:
134 std::vector<std::unique_ptr<svm_model>> svm_models_;
135 std::vector<std::unique_ptr<model>> linear_models_;
136 KernelParameters params_;
137 SVMLibrary library_type_;
138
145 torch::Tensor create_binary_labels(const torch::Tensor& y, int positive_class);
146
156 double train_binary_classifier(const torch::Tensor& X,
157 const torch::Tensor& y_binary,
158 const KernelParameters& params,
159 DataConverter& converter,
160 int class_idx);
161
165 void cleanup_models();
166 };
167
172 public:
177
182
183 TrainingMetrics fit(const torch::Tensor& X,
184 const torch::Tensor& y,
185 const KernelParameters& params,
186 DataConverter& converter) override;
187
188 std::vector<int> predict(const torch::Tensor& X,
189 DataConverter& converter) override;
190
191 std::vector<std::vector<double>> predict_proba(const torch::Tensor& X,
192 DataConverter& converter) override;
193
194 std::vector<std::vector<double>> decision_function(const torch::Tensor& X,
195 DataConverter& converter) override;
196
197 std::vector<int> get_classes() const override { return classes_; }
198
199 bool supports_probability() const override;
200
201 int get_n_classes() const override { return static_cast<int>(classes_.size()); }
202
203 MulticlassStrategy get_strategy_type() const override { return MulticlassStrategy::ONE_VS_ONE; }
204
205 private:
206 std::vector<std::unique_ptr<svm_model>> svm_models_;
207 std::vector<std::unique_ptr<model>> linear_models_;
208 std::vector<std::pair<int, int>> class_pairs_;
209 KernelParameters params_;
210 SVMLibrary library_type_;
211
220 std::pair<torch::Tensor, torch::Tensor> extract_binary_data(const torch::Tensor& X,
221 const torch::Tensor& y,
222 int class1,
223 int class2);
224
236 double train_pairwise_classifier(const torch::Tensor& X,
237 const torch::Tensor& y,
238 int class1,
239 int class2,
240 const KernelParameters& params,
241 DataConverter& converter,
242 int model_idx);
243
249 std::vector<int> vote_predictions(const std::vector<std::vector<double>>& decisions);
250
254 void cleanup_models();
255 };
256
262 std::unique_ptr<MulticlassStrategyBase> create_multiclass_strategy(MulticlassStrategy strategy);
263
264} // namespace svm_classifier
Data converter between libtorch tensors and SVM library formats.
Abstract base class for multiclass classification strategies.
std::vector< int > classes_
Unique class labels.
virtual int get_n_classes() const =0
Get number of classes.
virtual bool supports_probability() const =0
Check if the model supports probability prediction.
virtual MulticlassStrategy get_strategy_type() const =0
Get strategy type.
virtual std::vector< int > get_classes() const =0
Get unique class labels.
virtual TrainingMetrics fit(const torch::Tensor &X, const torch::Tensor &y, const KernelParameters &params, DataConverter &converter)=0
Train the multiclass classifier.
virtual std::vector< int > predict(const torch::Tensor &X, DataConverter &converter)=0
Predict class labels.
virtual ~MulticlassStrategyBase()=default
Virtual destructor.
bool is_trained_
Whether the model is trained.
virtual std::vector< std::vector< double > > predict_proba(const torch::Tensor &X, DataConverter &converter)=0
Predict class probabilities.
virtual std::vector< std::vector< double > > decision_function(const torch::Tensor &X, DataConverter &converter)=0
Get decision function values.
One-vs-One (OvO) multiclass strategy.
int get_n_classes() const override
Get number of classes.
std::vector< int > get_classes() const override
Get unique class labels.
std::vector< std::vector< double > > decision_function(const torch::Tensor &X, DataConverter &converter) override
Get decision function values.
bool supports_probability() const override
Check if the model supports probability prediction.
MulticlassStrategy get_strategy_type() const override
Get strategy type.
std::vector< int > predict(const torch::Tensor &X, DataConverter &converter) override
Predict class labels.
std::vector< std::vector< double > > predict_proba(const torch::Tensor &X, DataConverter &converter) override
Predict class probabilities.
TrainingMetrics fit(const torch::Tensor &X, const torch::Tensor &y, const KernelParameters &params, DataConverter &converter) override
Train the multiclass classifier.
~OneVsOneStrategy() override
Destructor.
One-vs-Rest (OvR) multiclass strategy.
bool supports_probability() const override
Check if the model supports probability prediction.
int get_n_classes() const override
Get number of classes.
std::vector< std::vector< double > > predict_proba(const torch::Tensor &X, DataConverter &converter) override
Predict class probabilities.
std::vector< int > get_classes() const override
Get unique class labels.
std::vector< int > predict(const torch::Tensor &X, DataConverter &converter) override
Predict class labels.
std::vector< std::vector< double > > decision_function(const torch::Tensor &X, DataConverter &converter) override
Get decision function values.
TrainingMetrics fit(const torch::Tensor &X, const torch::Tensor &y, const KernelParameters &params, DataConverter &converter) override
Train the multiclass classifier.
~OneVsRestStrategy() override
Destructor.
MulticlassStrategy get_strategy_type() const override
Get strategy type.
Training metrics structure.
Definition types.hpp:59

Generated on Sun Jun 22 2025 11:25:27 for SVM Classifier C++ by doxygen 1.9.8 </html>